
PyTorch source code for fully connected and convolutional autoencoders

Fully connected autoencoder

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
import numpy as np

# Load the data
train_data = MNIST(root='./mnist/', train=True, transform=tfs.ToTensor())  # 60,000 training images
print(train_data.data.size())      # (60000, 28, 28)
print(train_data.targets.size())   # (60000)
plt.imshow(train_data.data[0].numpy())  # show the first training image; matplotlib's default colormap renders it in color
plt.show()
train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)  # batch and shuffle

# Define the autoencoder
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(True),
            nn.Linear(400, 200),
            nn.ReLU(True),
            nn.Linear(200, 100),
            nn.ReLU(True),
            nn.Linear(100,3)
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 100),
            nn.ReLU(True),
            nn.Linear(100, 200),
            nn.ReLU(True),
            nn.Linear(200, 400),
            nn.ReLU(True),
            nn.Linear(400, 784),
            nn.Tanh()
        )
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
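
# Optional sanity check of the shapes above: one forward pass on a dummy
# batch should give a (64, 3) code and a (64, 784) reconstruction.
_enc, _dec = Autoencoder()(torch.randn(64, 28 * 28))
print(_enc.shape, _dec.shape)   # torch.Size([64, 3]) torch.Size([64, 784])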

autoencoder = Autoencoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.005)    # optimizer
loss_func = nn.MSELoss()                                            # loss: mean squared error
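
# Note on the output activation: Tanh maps to [-1, 1], while ToTensor()
# scales the targets to [0, 1]. Training still works because Tanh covers
# that range, but an nn.Sigmoid() output layer would match the targets
# exactly (a possible variation, not what the code above uses).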

# Create a canvas: a 2 x 10 grid of subplots; figsize is (width, height)
f, a = plt.subplots(2, 10, figsize=(10, 2))
plt.ion()
# In interactive mode, plt.plot(x) / plt.imshow(x) draw immediately, with no need for plt.show().
# If a script turns interactive mode on with ion() and never calls ioff(), the figure
# flashes and disappears instead of staying up; call ioff() before plt.show() to prevent that.

# Take the first 10 images for visualizing reconstructions
view_data = train_data.data[:10].view(-1, 28*28).float() / 255.
#print(view_data)
for i in range(10):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)))
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())  # hide the axis ticks

# Train
for epoch in range(10):
    for step, (x, b_label) in enumerate(train_loader):   # yields the batch index and the batch itself
        #print(x.shape)           # (64, 1, 28, 28)
        b_x = x.view(-1, 28*28)   # batch x, shape (batch, 784)
        #print(b_x.shape)         # (64, 784)
        b_y = x.view(-1, 28*28)   # batch y, shape (batch, 784): the target is the input itself

        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded, b_y)      # reconstruction loss
        optimizer.zero_grad()               # clear the gradients
        loss.backward()                     # backpropagation
        optimizer.step()                    # update the parameters

        if step % 100 == 0:        # report every 100 steps
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

            # Plot the current reconstructions
            encoded_data, decoded_data = autoencoder(view_data)
            #print(encoded_data.shape)
            for i in range(10):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)))
                a[1][i].set_xticks(()); a[1][i].set_yticks(())
            plt.draw(); plt.pause(0.05)  # pause for 0.05 s

plt.ioff()
plt.show()
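
Because the bottleneck compresses each image to just 3 values, the learned codes can be inspected directly as points in 3-D space, colored by digit label. A minimal sketch of that idea, assuming the autoencoder trained above and the imports from the script are still in scope (the sample size of 200 is arbitrary):

from mpl_toolkits.mplot3d import Axes3D  # registers the 3-D projection

view_200 = train_data.data[:200].view(-1, 28*28).float() / 255.
with torch.no_grad():
    codes, _ = autoencoder(view_200)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(codes[:, 0].numpy(), codes[:, 1].numpy(), codes[:, 2].numpy(),
           c=train_data.targets[:200].numpy(), cmap='tab10')
plt.show()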


Convolutional autoencoder

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
import matplotlib.pyplot as plt
import numpy as np

# Load the data
train_data = MNIST(root='./mnist/', train=True, transform=tfs.ToTensor())  # 60,000 training images
print(train_data.data.size())      # (60000, 28, 28)
print(train_data.targets.size())   # (60000)
plt.imshow(train_data.data[0].numpy())  # show the first training image; matplotlib's default colormap renders it in color
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)  # batch and shuffle

# Define the convolutional autoencoder
class conv_autoencoder(nn.Module):
    def __init__(self):
        super(conv_autoencoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),  # (b, 16, 10, 10)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),  # (b, 16, 5, 5)
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),  # (b, 8, 3, 3)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)  # (b, 8, 2, 2)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),  # (b, 16, 5, 5)
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),  # (b, 8, 15, 15)
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),  # (b, 1, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode
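
# Optional sanity check of the shape comments above: a dummy batch should
# encode to (1, 8, 2, 2) and decode back to (1, 1, 28, 28).
_enc, _dec = conv_autoencoder()(torch.zeros(1, 1, 28, 28))
print(_enc.shape, _dec.shape)   # torch.Size([1, 8, 2, 2]) torch.Size([1, 1, 28, 28])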

autoencoder = conv_autoencoder()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.005)  # optimizer
loss_func = nn.MSELoss()                               # loss: mean squared error

# Train
for epoch in range(1):
    for step, (x, b_label) in enumerate(train_loader):
        #print(x.shape)           # (64, 1, 28, 28)
        b_x = x                   # the conv layers take the images as-is, no flattening needed
        b_y = x

        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)      # reconstruction loss
        optimizer.zero_grad()               # clear the gradients
        loss.backward()                     # backpropagation
        optimizer.step()                    # update the parameters

        if step % 100 == 0:        # report every 100 steps
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

# Create a canvas: a 2 x 10 grid of subplots
f, a = plt.subplots(2, 10, figsize=(10, 2))
plt.ion()
# In interactive mode, plt.plot(x) / plt.imshow(x) draw immediately, with no need for plt.show().
# If a script turns interactive mode on with ion() and never calls ioff(), the figure
# flashes and disappears instead of staying up; call ioff() before plt.show() to prevent that.

# Take the first 10 images for visualizing reconstructions
view_data = train_data.data[:10].view(-1, 1, 28, 28).float() / 255.
#print(view_data.shape)  # (10, 1, 28, 28)
for i in range(10):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)))
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

encoded_data, decoded_data = autoencoder(view_data)
for i in range(10):
    a[1][i].clear()
    a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)))
    a[1][i].set_xticks(())
    a[1][i].set_yticks(())
plt.draw()
plt.pause(0.05)  # pause for 0.05 s
plt.ioff()
plt.show()
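
The script above only reports training loss. Reconstruction error on the held-out test split gives a quick check that the network is not simply memorizing the training images. A minimal sketch, assuming the conv_autoencoder trained above and the imports from the script are still in scope:

test_data = MNIST(root='./mnist/', train=False, transform=tfs.ToTensor())  # 10,000 test images
test_loader = Data.DataLoader(dataset=test_data, batch_size=256)
total, n = 0.0, 0
with torch.no_grad():
    for x, _ in test_loader:
        _, decoded = autoencoder(x)
        total += loss_func(decoded, x).item() * x.size(0)  # accumulate per-image MSE
        n += x.size(0)
print('test reconstruction MSE: %.4f' % (total / n))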