天天看點

基于Tensorflow和DCGAN生成動漫頭像實踐(一)

前言:學習tensorflow和深度學習有一段時間了,一直停留在運作别人的代碼和跑mnsit和cifar10資料集上,決定從簡單的動漫頭像生成着手代碼,經過無數的debug後終于完成大概,此間主要參考的有以下兩個代碼,一個是别人寫的DCGAN動漫頭像生成,另一個是pix2pix的tensorflow實作代碼。

動漫頭像生成:https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1阿城

pix2pix代碼:https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py

說明:本部分是資料是資料處理部分,采用的資料是别人提取好的動漫頭像,共50000多張,将這些圖檔轉化為tensorflow官方的标準資料TFrecord格式,這個格式的在tensorflow處理的時侯讀取速度會快不少

資料來源

百度網盤  密碼:g5qa

代碼

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
'''
     讀取圖檔資料并轉化為tensorflow官方的TFrecord格式
'''
import tensorflow as tf
import os
import sys
import time


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def get_TF():
    train_dir = "./faces/"    #定義讀取圖檔的路徑
    data  = []  
    for file in os.listdir(train_dir):   #将圖檔的路徑存儲到data list中
        data.append(train_dir+file)

    
    stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr  #如果沒有這部分會提示編碼錯誤
    reload(sys)                                     #python3的reload在其他包中
    sys.setdefaultencoding('utf-8') 
    sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stde  #改正reload之後print輸出不了的問題
  
    sess=tf.Session()
    file_at = 0
    start_time = time.time()
    for i in range(len(data)):
        
        image_path = data[i]    #枚舉每個圖檔的路徑
        image_raw_data = tf.gfile.FastGFile(image_path,'r').read()
        img_data = tf.image.decode_jpeg(image_raw_data,channels=3)    #将讀取到的圖檔按照jpeg的格式解壓成tensor的形式            
        img_data = img_data.eval(session=sess)
        image_raw = img_data.tobytes()    #将圖檔的tensor變成字元串

        example = tf.train.Example(features=tf.train.Features(feature={    #構造TFrecord形式的example
                'height':_int64_feature(img_data.shape[0]),
                'width':_int64_feature(img_data.shape[1]),
                'channel':_int64_feature(img_data.shape[2]),
                'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))   #之後需要的隻有'image_raw',其他可以不定義
                }))
        if i % 500 == 0:   #500個example存儲為一個TFrecord檔案
            file_at += 1
            filename = ("./TFrecord/data-tfrecords-%.5d" % file_at)
            if i>0:
                writer.close()
            writer = tf.python_io.TFRecordWriter(filename)
            print("%d steps,using time %f" % (i,time.time()-start_time))
            start_time =time.time()
        writer.write(example.SerializeToString()) #将examples寫入TFrecord檔案

 

    writer.close()
   
    
    
get_TF()
           

在程式實際運作的時候,一開始處理很快,但是後來生成一個TFrecord檔案就越運作越慢,查了資料沒發現其他人有出現這個問題,沒有解決。當然,也可以直接讀取原圖檔訓練。

繼續閱讀