天天看点

Deep Auto-encoder的代码实现

李宏毅讲Auto-encoder视频:链接地址

看了上面的Auto-encoder视频,想试着做一下里面的Deep Auto-encoder,看看效果如何,神经网络架构如下:

Deep Auto-encoder的代码实现

按照上面的网络架构,采用Keras实现Deep Auto-encoder,loss函数采用均方误差函数。

迭代之后的loss下降图:

Deep Auto-encoder的代码实现

最终的效果:

上面是原图,下面是由原图经过整个网络生成的图片:

Deep Auto-encoder的代码实现

效果并没有论文中显示的那么好,暂时还没找到原因,明天看看论文,看看能不能解决。

更新:论文上面说如果要训练得很好,那么需要将参数初始化到最优附近,然后通过反向传播算法进行fine-tune可以使得结果很好,初始化感觉比较复杂,难搞。

代码如下:

#coding=utf-8

from keras.datasets import mnist
from keras.layers import Input,Dense,Reshape
from keras.models import Sequential, Model
from keras.optimizers import Adam,SGD

import matplotlib.pyplot as plt
import sys
import os
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

img_height = 28
img_width = 28
def build_model():

    encoder = Sequential()
    encoder.add(Dense(units=1000))
    encoder.add(Dense(units=500))
    encoder.add(Dense(units=250))
    encoder.add(Dense(units=30))

    decoder = Sequential()
    decoder.add(Dense(units=250))
    decoder.add(Dense(units=500))
    decoder.add(Dense(units=1000))
    decoder.add(Dense(units=784))

    img_input = Input(shape=[img_width*img_height])
    code = encoder(img_input)

    reconstruct_img = decoder(code)

    combined = Model(img_input,reconstruct_img)

    optimizer = Adam(0.001)
    combined.compile(loss='mse', optimizer=optimizer)
    return encoder,decoder,combined


epochs = 100000
batch_size = 64
mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)
def train():
    losses = []
    for i in range(epochs):
        imgs,labels = mnist.train.next_batch(batch_size)
        imgs = imgs/2.0 - 1 # -1.0 - 1.0
        loss = combined.train_on_batch(imgs,imgs)
        if i % 5 == 0:
            print("epoch:%d,loss:%f"%(i,loss))
            losses.append(loss)
    plt.plot(np.arange(0,epochs,5),losses)
    
def test():
    mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)
    imgs,labels = mnist.test.next_batch(3)
    imgs = imgs*2.0 - 1 # -1.0 - 1.0
    output_imgs = combined.predict(imgs)
    output_imgs = (output_imgs+1)/2.0 # -1.0 - 1.0
    for i in range(3):
        plt.figure(1)
        plt.subplot(2,3,i+1) #两行一列的第一个子图
        plt.imshow(imgs[i].reshape((28,28)), cmap='gray')
        plt.subplot(2,3,i+1+3) #两行一列的第二个子图
        plt.imshow(output_imgs[i].reshape((28,28)), cmap='gray')
        
if __name__ == '__main__':
    encoder, decoder, combined = build_model()
    train()
    test()
           

继续阅读