天天看點

Auto-Encoders實戰

目錄

  • Outline
  • Auto-Encoder
  • 建立編解碼器
  • 訓練

Outline

  • Auto-Encoder
  • Variational Auto-Encoders

建立編解碼器

import os

# Set the native log level BEFORE TensorFlow is imported so that it is
# already in effect for any messages emitted while the library loads.
# '2' hides INFO and WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers

# Fix both random seeds so runs are repeatable.
tf.random.set_seed(22)
np.random.seed(22)

# The script relies on the TensorFlow 2.x API.
assert tf.__version__.startswith('2.')


def save_images(imgs, name):
    """Tile the first 100 entries of ``imgs`` into a 10x10 sheet and save it.

    Images are placed column by column: consecutive entries go down a
    column before moving one column to the right.

    Args:
        imgs: Indexable collection with at least 100 image arrays
            (presumably 28x28 uint8 grayscale; verify against the caller).
        name: Output path handed to ``Image.save``.
    """
    tile = 28
    per_side = 10
    canvas = Image.new('L', (280, 280))

    for index in range(per_side * per_side):
        # The quotient selects the column (x offset), the remainder the
        # row (y offset).
        col, row = divmod(index, per_side)
        patch = Image.fromarray(imgs[index], mode='L')
        canvas.paste(patch, (col * tile, row * tile))

    canvas.save(name)


# Hyperparameters.
h_dim = 20  # latent size: the 784-dim input is compressed to 20 dims
batchsz = 512
lr = 1e-3

# Load Fashion-MNIST and convert the pixel values to float32 divided by 255.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# The auto-encoder is unsupervised, so only the images go into the datasets.
train_db = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(batchsz * 5)
    .batch(batchsz)
)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)


class AE(keras.Model):
    """Fully-connected auto-encoder.

    The encoder maps its input through Dense layers of 256 and 128 units
    down to ``h_dim`` units; the decoder mirrors it (128, 256) back up to
    784 units. The final layer of each stack has no activation, so the
    model output is un-squashed (logits).
    """

    def __init__(self):
        super().__init__()

        # Encoder: input -> 256 -> 128 -> h_dim
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim)
        ])

        # Decoder: h_dim -> 128 -> 256 -> 784
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784)
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` to the latent code and decode it back.

        Args:
            inputs: Batch of flattened images, expected as [b, 784].
            training: Accepted for Keras API compatibility; not used.

        Returns:
            Reconstruction logits with 784 features per sample.
        """
        # [b, 784] ==> [b, h_dim]
        h = self.encoder(inputs)

        # [b, h_dim] ==> [b, 784]
        x_hat = self.decoder(h)

        return x_hat


# Instantiate the auto-encoder and build it for inputs with 784 features
# (batch size left unspecified). Prefer a tuple for shapes in TensorFlow.
model = AE()
model.build((None, 784))
model.summary()
      
(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)
Model: "ae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential (Sequential)      multiple                  236436    
_________________________________________________________________
sequential_1 (Sequential)    multiple                  237200    
=================================================================
Total params: 473,636
Trainable params: 473,636
Non-trainable params: 0
_________________________________________________________________
      

訓練

# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
optimizer = tf.optimizers.Adam(learning_rate=lr)

# Create the output directory once, up front.
os.makedirs('ae_images', exist_ok=True)

for epoch in range(10):

    for step, x in enumerate(train_db):

        # [b,28,28] ==> [b,784]
        x = tf.reshape(x, [-1, 784])

        with tf.GradientTape() as tape:
            x_rec_logits = model(x)

            # The model emits logits, so the sigmoid is applied inside
            # the loss via from_logits=True.
            rec_loss = tf.losses.binary_crossentropy(x,
                                                     x_rec_logits,
                                                     from_logits=True)
            # Average over the batch. Taking the minimum would optimise
            # only the single easiest sample of each batch.
            rec_loss = tf.reduce_mean(rec_loss)

        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # Evaluation: once per epoch, reconstruct the first test batch and
    # save the result as an image sheet named after the epoch.
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    # [b,784] ==> [b,28,28]
    x_hat = tf.reshape(x_hat, [-1, 28, 28])

    # Only the reconstructions are saved; convert them to 8-bit
    # integers for the image writer.
    x_out = x_hat.numpy() * 255.
    x_out = x_out.astype(np.uint8)
    save_images(x_out, 'ae_images/rec_epoch_%d.png' % epoch)
      
0 0 0.09717604517936707
0 100 0.12493347376585007
1 0 0.09747321903705597
1 100 0.12291513383388519
2 0 0.10048121958971024
2 100 0.12292417883872986
3 0 0.10093794018030167
3 100 0.12260882556438446
4 0 0.10006923228502274
4 100 0.12275046110153198
5 0 0.0993042066693306
5 100 0.12257824838161469
6 0 0.0967678651213646
6 100 0.12443818897008896
7 0 0.0965462476015091
7 100 0.12179268896579742
8 0 0.09197664260864258
8 100 0.12110235542058945
9 0 0.0913471132516861
9 100 0.12342415750026703