天天看點

計算圖像資料集的均值和方差(mean, std)用于transforms.Normalize()标準化

Pytorch圖像預處理時,通常使用transforms.Normalize(mean, std)對圖像按通道進行标準化,即減去均值,再除以方差。這樣做可以加快模型的收斂速度。其中參數mean和std分别表示圖像每個通道的均值和方差序列。

Imagenet資料集的均值和方差為:mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225),因為這是在百萬張圖像上計算而得的,是以我們通常見到在訓練過程中使用它們做标準化。而對于特定的資料集,選擇這個值的結果可能并不理想。接下來給出計算特定資料集的均值和方差的方法。

def getStat(train_data):
    '''
    Compute mean and variance for training data
    :param train_data: 自定義類Dataset(或ImageFolder即可)
    :return: (mean, std)
    '''
    print('Compute mean and variance for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'D:\cifar10_images\test', transform=None)
    print(getStat(train_dataset))      

  

getState()方法接收一個Dataset類(ImageFolder),然後累加所有圖像三個通道的均值和方差,最後除以圖像總數并傳回。

這裡用cifar10做的測試,測試集傳回的結果如下所示:

Compute mean and variance for training data.
10000
([0.4940607, 0.4850613, 0.45037037], [0.20085774, 0.19870903, 0.20153421])