天天看點

audoencoder自編碼練習

1.自編碼介紹

  • 自編碼是一種資料壓縮算法,類似于主成分分析法
  • 特性
    • 自編碼與資料相關,你的自編器是通過訓練才能使用的,如果你使用手寫數字作為訓練集,那麼編碼器在壓縮手寫數字是非常好的,對于其他資料是非常不好的。
    • 自動編碼器是有損的,即輸入和輸出的資料是有一些不同的
    • 自動編碼器是從資料樣本中自動學習的
      audoencoder自編碼練習

2.實戰壓縮手寫數字

2.1 導入我們的庫

import torch    
import torch.nn as nn    
import torch.utils.data as Data    
import torchvision    
import matplotlib.pyplot as plt    
import numpy as np    
# 随機種子
torch.manual_seed(1)    
           

2.2 定義超參數

# Hyper Parameters    
EPOCH = 2    
BATCH_SIZE = 64    
LR = 0.01    
DOWNLOAD_MNIST = True    
N_TEST_IMG = 5   
           

2.3 加載資料集,并展示其中一個資料

# 導入資料    
train_data = torchvision.datasets.MNIST(    
		root='./mnist/',    
        train=True,    
        transform=torchvision.transforms.ToTensor(),    
        download=DOWNLOAD_MNIST,    
        )    
        
# plot one example    
print(train_data.train_data.size())     # (60000, 28, 28)    
print(train_data.train_labels.size())   # (60000)    
plt.imshow(train_data.train_data[3].numpy(),cmap='gray')    
plt.title('%i' % train_data.train_labels[3])    
plt.show()  
# 打包資料
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE,shuffle=True)
           

2.4 定義我們的AutoEncoder

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder,self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28,128),
            nn.Tanh(),
            nn.Linear(128,64),
            nn.Tanh(),
            nn.Linear(64,12),
            nn.Tanh(),
            nn.Linear(12,3),    # 壓縮成三個神經元
            )
        self.decoder = nn.Sequential(
            nn.Linear(3,12),
            nn.Tanh(),
            nn.Linear(12,64),
            nn.Tanh(),
            nn.Linear(64,128),
            nn.Tanh(),
            nn.Linear(128,28*28),
            nn.Sigmoid()    # 映射到0-1範圍内
            )
˽˽˽˽
    def forward(self,x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

autoencoder = AutoEncoder()
optimizer = torch.optim.Adam(autoencoder.parameters(),lr=LR)                                             
loss_func = nn.MSELoss()
           

2.5 初始化圖檔

# 初始化照片
f ,a = plt.subplots(2,N_TEST_IMG,figsize=(15,12))
plt.ion()       # 開啟動圖                                                                               

# 畫初始的5個圖
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i],(28,28)),cmap='gray')
    a[0][i].set_xticks(())
           

2.6 訓練我們的AutoEncoder

for epoch in range(EPOCH):
    for step, (x, b_label) in enumerate (train_loader):
        b_x = x.view(-1,28*28)  # batch x,shape (batch,28*28)
        b_y = x.view(-1,28*28)  # batch x,shape (batch,28*28)
        
        encoded, decoded = autoencoder(b_x)

        loss = loss_func(decoded,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch:',epoch,'| train loss: %.4f' % loss.data.numpy())

            # 看看測試的情況   解碼後的圖檔
            _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i],(28,28)),cmap='gray')    # 展示我們壓縮在解碼的圖檔
                a[1][i].set_xticks(())
            plt.draw();plt.pause(0.05)

plt.ioff()
plt.show()
           

2.7 展示我們的壓縮後的資料

# 畫一個3D的圖,來展示  壓縮後的精髓資料
view_data = train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
encoded_data, _ = autoencoder(view_data)
# 3D圖的初始化
fig = plt.figure(2); ax = Axes3D(fig)
# 提取X,Y,Z
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
# 制定标簽
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    # 開始畫圖,并加上注釋
    # 彩色映射(難了解)
    c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)
# 設定坐标軸的刻度
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()
           

可以使Arch的neovim使用系統剪切版

sudo pacman -S xsel

繼續閱讀