天天看點

《深度學習之PyTorch實戰計算機視覺》學習筆記(13)

這部分是利用pytorch 進行實戰,利用自動編碼器來實作手寫字型的降噪問題

代碼基于python3.7, pytorch 1.0,cuda 10.0 .

PyTorch之自動編碼實戰(卷積神經網絡模型)¶

所謂的自動編碼器通俗點講就是通過線性模型或卷積模型将具有噪聲的圖像輸入進行提取特征,然後通過相同的操作進行解碼還原,這就是編碼解碼的過程和思想。 這部分實作的是利用自動編碼器模型解決的是一個去除圖檔馬賽克的問題,基于卷積神經網絡模型的神經網絡。

import torch
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
           
# 資料預處理
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean = [0.5],std = [0.5])]) # 注意到MNIST資料集的圖像是灰階圖像,單通道
# 資料讀取
dataset_train = datasets.MNIST(root = './data',
                              transform = transform,
                              train = True,
                              download = False)
dataset_test = datasets.MNIST(root = './data',
                             transform = transform,
                             train = False)
# 資料載入
train_load = torch.utils.data.DataLoader(dataset = dataset_train,batch_size = 64,shuffle = True)
test_load = torch.utils.data.DataLoader(dataset = dataset_test,batch_size = 64,shuffle = True)
           
# 資料可視化
images, label = next(iter(train_load))
print(images.shape)
images_example = torchvision.utils.make_grid(images)
images_example = images_example.numpy().transpose(1,2,0)
mean = 0.5
std = 0.5
images_example = images_example * std + mean
plt.imshow(images_example)
plt.show()
# 給圖像加噪聲
noisy_images = images_example + 0.5 * np.random.randn(*images_example.shape) # 這裡要加一個* ?? 不然會報元組不能加到整形的錯誤
noisy_images = np.clip(noisy_images,0.,1)  # 由于原始的MNSIT的資料集圖像的像素範圍是(0,1),是以加噪後要轉回(0,1)
plt.imshow(noisy_images)
plt.show()
           
torch.Size([64, 1, 28, 28])
           
《深度學習之PyTorch實戰計算機視覺》學習筆記(13)
《深度學習之PyTorch實戰計算機視覺》學習筆記(13)
# 搭建CNN模型
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super(AutoEncoder,self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1,64,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size = 2,stride = 2),
            torch.nn.Conv2d(64,128,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size = 2,stride = 2))
        self.decoder = torch.nn.Sequential(
            torch.nn.Upsample(scale_factor = 2, mode = 'nearest'),
            torch.nn.Conv2d(128,64,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor = 2, mode = 'nearest'),
            torch.nn.Conv2d(64,1,kernel_size = 3,stride = 1,padding = 1))
    def forward(self,input):
        output = self.encoder(input)
        output = self.decoder(output)
        return output
    
model = AutoEncoder()

Use_gpu = torch.cuda.is_available()
if Use_gpu:
    model = model.cuda()
    
print(model)     
           
AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): Upsample(scale_factor=2, mode=nearest)
    (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Upsample(scale_factor=2, mode=nearest)
    (4): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)
           

在以上代碼中出現torch.nn.Upsample類。這個類就是上采樣,作用就是對我們提取到的核心特征進行解壓,實作圖檔的重寫建構,傳遞給它的參數一共有兩個,分别是scale_factor和mode:前者用于确定解壓的倍數;後者用于定義圖檔重構的模式,可選擇的模式有nearest、linear、bilinear和trilinear,其中nearest是最鄰近法,linear是線性插值法,bilinear是雙線性插值法,trilinear是三線性插值法

# 設定優化器和損失函數
optimizer = torch.optim.Adam(model.parameters())
loss_f = torch.nn.MSELoss()
           
# 訓練網絡
epoch_n = 5
for epoch in range(epoch_n):
    running_loss = 0.0
    
    print('Epoch {}/{}'.format(epoch,epoch_n-1))
    print('-' * 10)
    
    for data in train_load:
        X_train, _ = data
        noisy_X_train = X_train + 0.5 * torch.randn(*X_train.shape)
        noisy_X_train = torch.clamp(noisy_X_train, 0., 1.)
        X_train,noisy_X_train = Variable(X_train.cuda()), Variable(noisy_X_train.cuda())
        train_pre = model(noisy_X_train)
        loss = loss_f(train_pre,X_train)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.data.item()
        
    print('Loss is :{:.4f}'.format(running_loss/len(dataset_train)))
           
Epoch 0/4
----------


e:\project3.7\lib\site-packages\torch\nn\modules\upsampling.py:129: UserWarning: nn.Upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.{} is deprecated. Use nn.functional.interpolate instead.".format(self.name))


Loss is :0.0009
Epoch 1/4
----------
Loss is :0.0005
Epoch 2/4
----------
Loss is :0.0004
Epoch 3/4
----------
Loss is :0.0004
Epoch 4/4
----------
Loss is :0.0004
           
# 驗證結果如何
data_loader_test = torch.utils.data.DataLoader(dataset = dataset_test,
                                              batch_size = 4,
                                              shuffle = True)
X_test,_ = next(iter(data_loader_test))

img1 = torchvision.utils.make_grid(X_test)
img1 = img1.numpy().transpose(1,2,0)
std = 0.5
mean = 0.5
img1 = img1 * std + mean
noisy_X_test = img1 + 0.5 * np.random.rand(*img1.shape)
noisy_X_test = np.clip(noisy_X_test,0.,1.)

plt.figure()
plt.imshow(noisy_X_test)

img2 = X_test + 0.5 * torch.randn(*X_test.shape)
img2 = torch.clamp(img2,0.,1.)

img2 = Variable(img2.cuda())

test_pred = model(img2)

img_test = test_pred.data.view(-1,1,28,28)
img2 = torchvision.utils.make_grid(img_test)
img2 = img2.cpu()   # 這裡要将在cuda()上的tensor資料轉到cpu上,否則會報錯,無法從tensor.cuda()轉為numpy
img2 = img2.numpy().transpose(1,2,0)
img2 = img2 * std + mean
img2 = np.clip(img2,0.,1.)
plt.figure()
plt.imshow(img2)
           
<matplotlib.image.AxesImage at 0x1de68965320>
           
《深度學習之PyTorch實戰計算機視覺》學習筆記(13)
《深度學習之PyTorch實戰計算機視覺》學習筆記(13)

繼續閱讀