最近由於工作需要,必須迅速掌握 TensorFlow 這一塊的基礎知識。今天就講一講如何將圖檔轉化為 TFRecord 檔案來儲存資料,然後再從 TFRecord 資料檔案中載入出原圖檔。
這一塊不涉及太多理論知識,只需要熟練掌握 TensorFlow 的 API 即可,最好的辦法就是努力學習。部落格裡上傳的每一份代碼都是我在本地 PyCharm 裡調試過的,希望對你也有幫助。
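我的環境是 Python 3.6 + tensorflow-gpu。因為 tensorflow-gpu 版本處理圖檔資料比較快,所以我選擇了 GPU 版本,大家可以根據自己的實際情況選擇相應的 CPU 或 GPU 版本。注意:下面的代碼使用的是 TensorFlow 1.x 的 API(如 tf.python_io.TFRecordWriter、tf.train.string_input_producer、tf.Session),在 TensorFlow 2.x 中不能直接使用。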
ok,直接上代碼:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 04_generate_tfrecord.py
# @DateTime : 2019-11-24 15:10
# @Author : 皮皮蝦
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import tensorflow as tf
from sklearn.utils import shuffle
def load_image(src_path):
    """Collect image paths under ``src_path`` and derive a label from each folder name.

    ``src_path`` must contain one sub-directory per class. Every regular file
    inside a sub-directory is treated as an image of that class. Stray files
    directly under ``src_path`` and nested directories inside a class folder
    are skipped.

    Label convention (intended mapping ``{"man": 0, "woman": 1}``): the folder
    name ``"man"`` maps to 0 and every other folder name maps to 1.

    Args:
        src_path: Root directory holding the per-class sub-directories.

    Returns:
        A tuple ``(image_paths, labels)`` of NumPy arrays. Both arrays are
        shuffled together by ``sklearn.utils.shuffle``, so each path stays
        aligned with its label.
    """
    image_path_list = []
    map_label_list = []
    for class_name in os.listdir(src_path):
        class_dir = os.path.join(src_path, class_name)
        # A plain file at the top level (e.g. ".DS_Store") is not a class
        # folder; calling os.listdir on it would raise NotADirectoryError.
        if not os.path.isdir(class_dir):
            continue
        # "man" -> 0, anything else -> 1 (same rule as before).
        label = 0 if class_name == "man" else 1
        for image_name in os.listdir(class_dir):
            image_path = os.path.join(class_dir, image_name)
            # Nested directories cannot be opened as images later on.
            if not os.path.isfile(image_path):
                continue
            image_path_list.append(image_path)
            map_label_list.append(label)
    return shuffle(np.asarray(image_path_list), np.asarray(map_label_list))
def create_tfrecord(filenames, labels, save_path, image_size=(256, 256)):
    """Write images and their integer labels into a single TFRecord file.

    Each image is opened with PIL, converted to RGB and resized to
    ``image_size``. Its raw pixel bytes are stored under the ``"img_raw"``
    feature and its label is stored as an int64 under ``"label"``.

    Args:
        filenames: Sequence of image file paths.
        labels: Sequence of integer labels, aligned index-by-index with
            ``filenames``.
        save_path: Destination path of the TFRecord file.
        image_size: ``(width, height)`` passed to ``PIL.Image.resize``. The
            reader in this script reshapes to ``[256, 256, 3]``, so change both
            together.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """
    if len(filenames) != len(labels):
        raise ValueError(
            "filenames and labels must have the same length "
            "(got %d and %d)" % (len(filenames), len(labels)))
    # The writer is used as a context manager so the file is flushed and
    # closed even if one of the images fails to load.
    with tf.python_io.TFRecordWriter(path=save_path) as writer:
        for filename, label in tqdm(zip(filenames, labels), total=len(labels)):
            # Close PIL's file handle after use. Converting to RGB ensures
            # grayscale / RGBA / palette images still give exactly 3 bytes per
            # pixel, which the [256, 256, 3] reshape on the read side requires.
            with Image.open(fp=filename, mode="r") as img:
                img_raw = img.convert("RGB").resize(size=image_size).tobytes()
            # Build the Example protocol buffer that wraps the label and image.
            example = tf.train.Example(features=tf.train.Features(feature={
                # int() turns NumPy integer scalars into plain Python ints.
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
                "img_raw": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            # Serialize to a string and append it as one record.
            writer.write(record=example.SerializeToString())
def read_and_decode_tfrecord_file(filenames, flag="train", batch_size=3):
    """Build a TF1 queue-based pipeline that decodes records written by ``create_tfrecord``.

    Args:
        filenames: Path of a TFRecord file, or a list of paths.
        flag: With ``"train"``, the file queue cycles indefinitely in shuffled
            order, pixels are scaled to [-0.5, 0.5] and batches are returned.
            With any other value, the data is read exactly once, in order,
            and single un-normalized uint8 examples are returned.
        batch_size: Batch size used when ``flag == "train"``.

    Returns:
        ``(image_batch, label_batch)`` when training. Otherwise ``(image,
        label)``, where ``image`` is a uint8 tensor of shape
        ``[256, 256, 3]`` and ``label`` is a scalar int32 tensor.

    Note:
        The tensors are fed by queue runners. Before evaluating them, the
        caller must run ``tf.local_variables_initializer()`` (required by
        ``num_epochs``) and call ``tf.train.start_queue_runners``.
    """
    # string_input_producer expects a list of file names. Accept a single
    # path as well as a list of paths.
    if isinstance(filenames, (str, bytes)):
        filenames = [filenames]
    # Build the file-name queue (shuffle defaults to True).
    if flag == "train":
        filename_queue = tf.train.string_input_producer(string_tensor=filenames)
    else:
        filename_queue = tf.train.string_input_producer(
            string_tensor=filenames, num_epochs=1, shuffle=False)
    reader = tf.TFRecordReader()
    # reader.read returns (key, value); the key (record/file name) is unused.
    _, serialized_example = reader.read(queue=filename_queue)
    # Parse the feature map that holds the image bytes and the label.
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.int64),
            "img_raw": tf.FixedLenFeature([], tf.string),
        })
    # Decode the raw byte string back into a flat uint8 pixel array.
    image = tf.decode_raw(bytes=features["img_raw"], out_type=tf.uint8)
    image = tf.reshape(tensor=image, shape=[256, 256, 3])
    label = tf.cast(x=features["label"], dtype=tf.int32)
    if flag == "train":
        # Scale [0, 255] pixels to [-0.5, 0.5] and group them into batches.
        image = tf.cast(x=image, dtype=tf.float32) * (1. / 255) - 0.5
        image_batch, label_batch = tf.train.batch(
            tensors=[image, label],
            batch_size=batch_size,
            capacity=20)
        return image_batch, label_batch
    return image, label
if __name__ == '__main__':
    # Pipeline: images on disk -> TFRecord file -> decoded images written back
    # to disk, one sub-directory per label.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        default=r" ",
        type=str,
        required=False,
        help="root directory containing one sub-directory of images per class"
    )
    parser.add_argument(
        "--save_tfrecord_path",
        default=r" ",
        type=str,
        required=False,
        help="path of the TFRecord file to write"
    )
    parser.add_argument(
        "--save_iamge_path",
        # Correctly spelled alias. argparse takes dest from the first long
        # option, so the value is still read as FLAGS.save_iamge_path.
        "--save_image_path",
        default=r" ",
        type=str,
        required=False,
        help="directory for the decoded images (deleted and recreated on each run)"
    )
    FLAGS, _ = parser.parse_known_args()
    images_path, labels = load_image(src_path=FLAGS.input_path)
    create_tfrecord(filenames=images_path, labels=labels, save_path=FLAGS.save_tfrecord_path)
    image, label = read_and_decode_tfrecord_file(filenames=FLAGS.save_tfrecord_path, flag="test")
    # Start from an empty output directory.
    if tf.gfile.Exists(filename=FLAGS.save_iamge_path):
        tf.gfile.DeleteRecursively(dirname=FLAGS.save_iamge_path)
    tf.gfile.MakeDirs(dirname=FLAGS.save_iamge_path)
    with tf.Session() as sess:
        # num_epochs=1 in the "test" pipeline is backed by a local variable.
        sess.run(tf.local_variables_initializer())
        # The coordinator lets us stop the queue-runner threads cleanly.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        # Label sub-directories already created in this run.
        created_dirs = set()
        try:
            i = 1
            while not coord.should_stop():
                example_image, example_label = sess.run(fetches=[image, label])
                example_label = str(example_label)
                # Join with a separator. Plain string concatenation would put
                # the file outside the sub-directory created below.
                label_dir = os.path.join(FLAGS.save_iamge_path, example_label)
                if example_label not in created_dirs:
                    created_dirs.add(example_label)
                    tf.gfile.MakeDirs(dirname=label_dir)
                # Label 0 was mapped from "man"; every other label from "woman".
                prefix = "man_" if example_label == "0" else "woman_"
                Image.fromarray(obj=example_image, mode="RGB").save(
                    os.path.join(label_dir, prefix + str(i) + ".jpg"))
                i += 1
        except tf.errors.OutOfRangeError:
            # Raised once the single epoch of the "test" pipeline is exhausted.
            print("Done Test -- epoch limit reaches")
        finally:
            coord.request_stop()
            coord.join(threads=threads)
以上就是將圖檔資料轉化為 TFRecord 檔案、再從 TFRecord 檔案中讀取資料的完整過程。建議大家在寫的時候,一定要點進源碼看看 API 的參數,知道每個參數要傳什麼類型,才能一步步構造出整個過程;否則死記硬背的代碼沒有任何用處。