天天看點

多光譜遙感分類(一):資料集制作

目錄

多光譜遙感分類(一):資料集製作(本文)。

多光譜遙感分類(二):VGG微調
多光譜遙感分類(三):CNN提取特徵+RF分類
多光譜遙感分類(四):使用GLCM+RF
多光譜遙感分類(五):代碼優化+自定義模型

也可參考:遙感分類的一種採樣方法。

描述

代碼源於很久以前練手的一個Demo,時間久了,許多修改過的版本都已遺失,目前只剩下這個簡陋版本。讀者如有相關需求,可從這些隻字片語中各取所需、自行修改。由於代碼較為混亂且簡陋,不再上傳GitHub。

所用資料為多光譜遙感影像(.tif,由ArcGIS導出的RGB彩色圖像),以及摳圖所得的點檔案(.shp)。點檔案由摳圖所得的面檔案在ArcGIS中隨機生成點得到,至少需包含一個字段作為標籤(程式讀取屬性表的第一個字段作為標籤)。

工具篇

根據點shp檔案(樣本點集合),以每個點為中心,對柵格圖像的第3、2、1波段切出固定大小的圖塊,並保存到對應標籤的資料夾下。注意:shp與tif的投影坐標系必須一致。

from osgeo import gdal
import numpy as np
import shapefile
import cv2
import os

# Edge length (pixels) of the square window cut around each sample point, and a
# band count that the functions in this script do not read.
size, bands = 64, 3

# Raster to sample from and the point shapefile holding the sample locations;
# both files must use the same projected coordinate system.
dataset = gdal.Open(r"E:\資料2\test_tif_peizhun_subset_proj_.tif")
rer = shapefile.Reader(r'E:\shps\test.shp')

def __createDir(path):
    """Ensure directory *path* (and any missing parents) exists.

    Terminates the script with exit status 1 if the directory cannot be
    created, including when *path* already exists as a regular file.
    """
    try:
        # exist_ok=True removes the race between an existence check and creation,
        # but still raises if *path* is an existing non-directory.
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print("建立檔案夾失敗", err)
        raise SystemExit(1) from err

def __getACell(geo, pos):
    """Cut a ``size`` x ``size`` window centred on a map coordinate from ``dataset``.

    Args:
        geo: geotransform of ``dataset`` (as returned by ``GetGeoTransform()``).
            Only the origin (geo[0], geo[3]) and pixel sizes (geo[1], geo[5])
            are used; rotation terms are ignored, so a north-up raster is assumed.
        pos: (x, y) coordinate in the raster's projection.

    Returns:
        A (size, size, 3) uint8 array with bands 3, 2, 1 stacked along the last
        axis, or None if the window does not lie completely inside the raster
        or reading fails.
        NOTE(review): bands 3/2/1 presumably map to B/G/R for cv2.imwrite, and
        values are cast to uint8 without rescaling, so >8-bit data is not
        supported. Confirm against the input raster.
    """
    try:
        xoffset = int((pos[0] - geo[0]) / geo[1])
        yoffset = int((pos[1] - geo[3]) / geo[5])

        print("pixels: x= %d,y= %d" % (xoffset, yoffset))

        # Top-left corner of the window. The bounds check is band-independent,
        # so it is done once instead of once per band.
        left = int(xoffset - size / 2)
        top = int(yoffset - size / 2)
        if (left < 0 or top < 0
                or left + size > dataset.RasterXSize
                or top + size > dataset.RasterYSize):
            return None

        output = [dataset.GetRasterBand(i).ReadAsArray(left, top, size, size)
                  for i in [3, 2, 1]]
        img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
    except Exception:
        # Any GDAL/NumPy failure (e.g. ReadAsArray returning None) skips the point;
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return None
    return img

def getShpDataForNum():
    """Cut one image chip per sample point and save it under its label's folder.

    The first attribute field of each shapefile record (``rer``) is used as the
    label. Chips are written to ``data/org/<label>/<label>.<index>.jpg``.
    Points with empty geometry, or whose window does not fit inside the raster,
    are skipped.
    """
    labels = [record[0] for record in rer.records()]
    for label in set(labels):
        __createDir(os.path.join("data/org", str(label)))

    # The geotransform is the same for every point, so fetch it once.
    geo = dataset.GetGeoTransform()

    for i, label in enumerate(labels):
        print("deal %d: " % (i + 1))
        points = rer.shape(i).points
        if not points:
            # Null/empty geometry: there is no coordinate to sample at.
            print("the shape of points %d is empty." % (i))
            continue
        img = __getACell(geo, points[0])
        if img is None:
            print("the area of points %d is out range." % (i))
            continue
        out_path = "data/org/%s/%s.%d.jpg" % (label, label, i)
        # cv2.imwrite signals failure only through its return value.
        if not cv2.imwrite(out_path, img):
            print("failed to write %s" % out_path)
            continue
        print(out_path)
    print("deal finish,to numpy array.")


getShpDataForNum()

如下,將上述所得檔案按類別隨機打亂後,以3:1的比例拆分為訓練集和測試集。

import os
import shutil
import random

def createDir(path):
    """Ensure directory *path* (and any missing parents) exists.

    Terminates the script with exit status 1 if the directory cannot be
    created, including when *path* already exists as a regular file.
    """
    try:
        # exist_ok=True avoids the check-then-create race of os.path.exists().
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print("建立檔案夾失敗", err)
        raise SystemExit(1) from err

createDir("data/train/")
createDir("data/test/")

# Fraction of every class that is held out for the test set.
TEST_RATIO = 0.25

src_root = 'data/org/'
for dir_item in os.listdir(src_root):
    class_dir = os.path.join(src_root, dir_item)
    if not os.path.isdir(class_dir):
        # Ignore stray files next to the per-class folders.
        continue

    createDir("data/train/" + dir_item)
    createDir("data/test/" + dir_item)

    org_data = os.listdir(class_dir)
    random.shuffle(org_data)
    # Split at an explicit index: with num == 0 the old negative slices
    # (org_data[:-0] / org_data[-0:]) sent every sample of a small class to test.
    num = int(len(org_data) * TEST_RATIO)
    split = len(org_data) - num
    train_files, test_files = org_data[:split], org_data[split:]

    print(src_root + dir_item + " start.")
    for d in train_files:
        shutil.copyfile(os.path.join(class_dir, d), os.path.join("data/train", dir_item, d))
    for d in test_files:
        shutil.copyfile(os.path.join(class_dir, d), os.path.join("data/test", dir_item, d))
    print(src_root + dir_item + " finished")
import os
import seaborn as sns
import matplotlib.pyplot as plt
def show(path, title):
    """Draw a bar chart of how many files each class folder under *path* holds.

    Every sub-directory of *path* is treated as one class, and its bar height is
    the number of entries in that sub-directory. Each bar is annotated with its
    count. Blocks in ``plt.show()`` until the window is closed.
    """
    # Sort for a stable bar order (os.listdir order is arbitrary), and skip
    # plain files, which cannot be listed.
    classes = sorted(name for name in os.listdir(path)
                     if os.path.isdir(os.path.join(path, name)))
    counts = [len(os.listdir(os.path.join(path, name))) for name in classes]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['axes.unicode_minus'] = False  # keep minus signs renderable with SimHei
    # seaborn >= 0.12 accepts x / y only as keyword arguments.
    sns.barplot(x=classes, y=counts)
    plt.xlabel("樣本類型")
    plt.ylabel("數量")
    plt.title(title)

    for i, count in enumerate(counts):
        plt.text(i, count + 2, "%d" % count, ha="center", va="bottom")
    plt.show()


show(r"data/1_train", "訓練集源資料采樣集")
from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2,shutil


class Tiff:
    """Cut fixed-size image chips from a raster at positions listed in a table.

    The table is produced by ``contact_pos_feather``: the position columns are
    followed by the feature columns. ``get_cells`` treats column 1 as the x
    (column) pixel offset, column 0 as the y (row) pixel offset, and the
    second-to-last column as the class label.
    NOTE(review): confirm this column layout against the input files.
    """

    def createDir(self, path):
        """Ensure directory *path* exists; exit with status 1 if it cannot be created."""
        try:
            # exist_ok=True avoids the check-then-create race.
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            print("建立檔案夾失敗", err)
            raise SystemExit(1) from err

    def __init__(self, pos_src, other_feather, contact_src, size=128, bands=(3, 2, 1),
                 tif_src=r"D:/lishihang/jiangxia_simple/ZY3_GS_jiangxia1.tif"):
        """Open the raster, build the merged table if needed, and load it.

        Args:
            pos_src: space-separated, header-less position table.
            other_feather: tab-separated, header-less feature table.
            contact_src: CSV path for the merged table (reused if it already exists).
            size: edge length of the square sampling window, in pixels.
            bands: 1-based raster band indices to stack, in output channel order.
            tif_src: path of the raster to sample from.
        """
        self.dataset = gdal.Open(tif_src)  # None if GDAL cannot open the file
        self.size = size  # sampling window edge length
        # Copy, so that instances never share one default sequence.
        self.bands = list(bands)
        self.contact_pos_feather(pos_src, other_feather, contact_src)
        self.fea = pd.read_csv(contact_src, header=None)

    def get_cell(self, pos_x, pos_y):
        """Return a (size, size, len(bands)) uint8 chip centred on pixel (pos_x, pos_y).

        Returns None if the window does not lie completely inside the raster,
        or if reading fails (e.g. no dataset, or a NaN position).
        """
        try:
            left = int(pos_x - self.size / 2)
            top = int(pos_y - self.size / 2)
            # Check bounds explicitly instead of relying on ReadAsArray failing.
            if (left < 0 or top < 0
                    or left + self.size > self.dataset.RasterXSize
                    or top + self.size > self.dataset.RasterYSize):
                return None
            output = [self.dataset.GetRasterBand(i).ReadAsArray(left, top, self.size, self.size)
                      for i in self.bands]
            return np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
        except Exception:
            return None

    def get_cells(self, target_src):
        """Save one chip per table row as ``<target_src>/<label>/<label>.<row>.jpg``.

        Rows whose chip cannot be cut (see ``get_cell``) are skipped.

        Raises:
            OSError: if the raster could not be opened in ``__init__``.
        """
        if self.dataset is None:
            raise OSError("raster could not be opened by gdal.Open")
        fea_len = len(self.fea)

        # Read whole columns: taking a row (iloc[i, :]) of a mixed int/float table
        # upcasts it to float64. Label 1 would then become "1.0" and be written into
        # a folder that was never created.
        ys = self.fea.iloc[:, 0].to_numpy()
        xs = self.fea.iloc[:, 1].to_numpy()
        labels = self.fea.iloc[:, -2].to_numpy()

        self.createDir(target_src)
        for label in set(labels):
            self.createDir("%s/%s" % (target_src, label))

        print("fea length: %d" % fea_len)

        for i, (pos_x, pos_y, label) in enumerate(zip(xs, ys, labels)):
            img = self.get_cell(pos_x, pos_y)
            if img is None:
                continue
            out_path = "%s/%s/%s.%d.jpg" % (target_src, label, label, i)
            # cv2.imwrite signals failure only through its return value.
            if not cv2.imwrite(out_path, img):
                print("failed to write %s" % out_path)
            if i % 1000 == 0:
                print("%d/%d hava finsh save." % (i, fea_len))

    def contact_pos_feather(self, pos_src, other_feather, target):
        """Join the position and feature tables column-wise and write them to *target*.

        *pos_src* is read as space-separated and *other_feather* as tab-separated,
        both without a header. Rows are paired by position. The result is written
        as CSV without header or index. Does nothing if *target* already exists.
        """
        if os.path.exists(target):
            print("檔案已存在")
            return
        pos = pd.read_csv(pos_src, header=None, sep=' ')
        feather = pd.read_csv(other_feather, header=None, sep='\t')
        fea = pd.concat([pos, feather], axis=1)
        print("pos Length=%d,feather Length=%d,fea Length=%d" % (len(pos), len(feather), len(fea)))
        if len(pos) != len(feather):
            # concat aligns on the row index, so the shorter table is padded with NaN.
            print("warning: pos and feather lengths differ; missing values are NaN")
        fea.to_csv(target, index=False, header=False)



if __name__ == '__main__':
    # Merge the training position table and feature table into tr_1.txt.
    tiff = Tiff(pos_src=r"D:/tr_sample_1.txt",
                other_feather=r"D:/train1.txt",
                contact_src=r"tr_1.txt")
    # Test-split variant:
    #   tiff = Tiff(r"D:/te_sample_1.txt", r"D:/test1.txt", r"te_1.txt")
    # Cut the image chips afterwards with:
    #   tiff.get_cells("data/1_test")