前言:學習tensorflow和深度學習有一段時間了,一直停留在運作别人的代碼和跑mnsit和cifar10資料集上,決定從簡單的動漫頭像生成着手代碼,經過無數的debug後終于完成大概,此間主要參考的有以下兩個代碼,一個是别人寫的DCGAN動漫頭像生成,另一個是pix2pix的tensorflow實作代碼。
動漫頭像生成:https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1阿城
pix2pix代碼:https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py
說明:本部分是資料是資料處理部分,采用的資料是别人提取好的動漫頭像,共50000多張,将這些圖檔轉化為tensorflow官方的标準資料TFrecord格式,這個格式的在tensorflow處理的時侯讀取速度會快不少
資料來源
百度網盤 密碼:g5qa
代碼
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
'''
讀取圖檔資料并轉化為tensorflow官方的TFrecord格式
'''
import tensorflow as tf
import os
import sys
import time
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def get_TF():
train_dir = "./faces/" #定義讀取圖檔的路徑
data = []
for file in os.listdir(train_dir): #将圖檔的路徑存儲到data list中
data.append(train_dir+file)
stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr #如果沒有這部分會提示編碼錯誤
reload(sys) #python3的reload在其他包中
sys.setdefaultencoding('utf-8')
sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stde #改正reload之後print輸出不了的問題
sess=tf.Session()
file_at = 0
start_time = time.time()
for i in range(len(data)):
image_path = data[i] #枚舉每個圖檔的路徑
image_raw_data = tf.gfile.FastGFile(image_path,'r').read()
img_data = tf.image.decode_jpeg(image_raw_data,channels=3) #将讀取到的圖檔按照jpeg的格式解壓成tensor的形式
img_data = img_data.eval(session=sess)
image_raw = img_data.tobytes() #将圖檔的tensor變成字元串
example = tf.train.Example(features=tf.train.Features(feature={ #構造TFrecord形式的example
'height':_int64_feature(img_data.shape[0]),
'width':_int64_feature(img_data.shape[1]),
'channel':_int64_feature(img_data.shape[2]),
'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])) #之後需要的隻有'image_raw',其他可以不定義
}))
if i % 500 == 0: #500個example存儲為一個TFrecord檔案
file_at += 1
filename = ("./TFrecord/data-tfrecords-%.5d" % file_at)
if i>0:
writer.close()
writer = tf.python_io.TFRecordWriter(filename)
print("%d steps,using time %f" % (i,time.time()-start_time))
start_time =time.time()
writer.write(example.SerializeToString()) #将examples寫入TFrecord檔案
writer.close()
get_TF()
在程式實際運作的時候,一開始處理很快,但是後來生成一個TFrecord檔案就越運作越慢,查了資料沒發現其他人有出現這個問題,沒有解決。當然,也可以直接讀取原圖檔訓練。