天天看點

Tensorflow資料集制作專題【四】— 将圖檔檔案制作成TFRecord資料集,并從TFRecord檔案集讀取資料

最近由于工作需要, 需要迅速掌握tensorflow這一塊的基礎知識, 今天就講一講如何将圖檔轉化為tfrecord檔案進行資料的儲存, 然後從tfrecord資料檔案中再加載出原圖檔。

這一塊不涉及過多的理論知識, 需要熟練掌握tensorflow的API即可, 最好的辦法就是努力學習。部落格裡傳的每一份代碼都是自己在本地Pycharm裡面調試過的,希望對你也有幫助。

我的環境是 python3.6 + tensorflow-gpu, 因為tensorflow-gpu版本處理圖檔資料比較塊, 是以我選擇了gpu版本,大家可以根據自己的實際情況選擇相應的cpu或者gpu版本。

ok,直接上代碼:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 04_generate_tfrecord.py
# @DateTime :  2019-11-24 15:10
# @Author : 皮皮蝦

import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import tensorflow as tf
from sklearn.utils import shuffle


def load_image(src_path):
    image_path_list = []
    real_label_list = []
    for _dir_ in os.listdir(src_path):
        image_dirname_path = os.path.join(src_path, _dir_)
        for image in os.listdir(image_dirname_path):
            image_path = os.path.join(image_dirname_path, image)
            image_path_list.append(image_path)
            real_label_list.append(_dir_)
    # 将原始的label進行轉化,{"man": 0, "woman": 1}
    map_label_list = []
    for label in real_label_list:
        if label == "man":
            map_label_list.append(0)
        else:
            map_label_list.append(1)
    return shuffle(np.asarray(image_path_list), np.asarray(map_label_list))


def create_tfrecord(filenames, labels, save_path):
    # 建構writer, 向檔案中寫入資料
    writer = tf.python_io.TFRecordWriter(path=save_path)
    for index in tqdm(range(0, len(labels))):
        img = Image.open(fp=filenames[index], mode="r")
        img = img.resize(size=(256, 256))
        # 将圖檔轉化為二進制
        img_raw = img.tobytes()
        # 構造example協定塊, 封裝label和image
        example = tf.train.Example(features=tf.train.Features(feature={
            # 存儲标簽
            "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[index]])),
            # 存儲圖檔
            "img_raw":tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
        # 序列化為字元串
        writer.write(record=example.SerializeToString())
    writer.close()


def read_and_decode_tfrecord_file(filenames, flag="train", batch_size=3):
    # 根據檔案名生成一個隊列
    if flag == "train":
        # 預設shuffle=True
        filename_queue = tf.train.string_input_producer(string_tensor=[filenames])
    else:
        filename_queue = tf.train.string_input_producer(string_tensor=[filenames], num_epochs=1, shuffle=False)
    # 構造閱讀器
    reader = tf.TFRecordReader()
    # 讀取檔案 _ 傳回的是檔案名
    _, serialize_example = reader.read(queue=filename_queue)
    # 取出包含image和label的feature
    features = tf.parse_single_example(serialized=serialize_example,
                                       features={
                                           "label": tf.FixedLenFeature([], tf.int64),
                                           "img_raw": tf.FixedLenFeature([], tf.string)
                                       })
    # 調用tf.decode_raw将字元串解析成圖像對應的像素數組
    image = tf.decode_raw(bytes=features["img_raw"], out_type=tf.uint8)
    image = tf.reshape(tensor=image, shape=[256, 256, 3])
    # 轉換标簽的類型
    label = tf.cast(x=features["label"], dtype=tf.int32)
    # 如果是訓練使用, 則應該将其歸一化,并按batch組合
    if flag == "train":
        # 歸一化
        image = tf.cast(x=image, dtype=tf.float32) * (1./255) - 0.5
        # 生成批次資料
        image_batch,label_batch = tf.train.batch(tensors=[image, label],
                                                 batch_size=batch_size,
                                                 capacity=20)
        return images_path, label_batch

    return image, label


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        default=r" ",
        type=str,
        required=False,
        help="image input path"
    )
    parser.add_argument(
        "--save_tfrecord_path",
        default=r" ",
        type=str,
        required=False,
        help="image input path"
    )
    parser.add_argument(
        "--save_iamge_path",
        default=r" ",
        type=str,
        required=False,
        help="image input path"
    )
    FLAGS, _ = parser.parse_known_args()

    images_path, labels = load_image(src_path=FLAGS.input_path)
    create_tfrecord(filenames=images_path, labels=labels, save_path=FLAGS.save_tfrecord_path)
    image, label = read_and_decode_tfrecord_file(filenames=FLAGS.save_tfrecord_path, flag="test")

    if tf.gfile.Exists(filename=FLAGS.save_iamge_path):
        tf.gfile.DeleteRecursively(dirname=FLAGS.save_iamge_path)
    tf.gfile.MakeDirs(dirname=FLAGS.save_iamge_path)
    # 建立session, 開啟會話
    with tf.Session() as sess:
        # 初始化本地變量
        local_op = tf.local_variables_initializer()
        sess.run(local_op)
        # 建立一個線程協調器
        coord = tf.train.Coordinator()
        # 開啟多線程
        threads = tf.train.start_queue_runners(coord=coord)
        # 建立集合, 存放子檔案夾
        my_set= set([])
        try:
            i = 1
            while True:
                # 取出image和label
                _enxmaple_image_, _example_label_ = sess.run(fetches=[image,label])
                _example_label_ = str(_example_label_)
                if _example_label_ not in my_set:
                    my_set.add(_example_label_)
                    # 建立子檔案夾
                    tf.gfile.MakeDirs(dirname=os.path.join(FLAGS.save_iamge_path, _example_label_))
                # 轉換圖檔格式
                _image_ = Image.fromarray(obj=_enxmaple_image_, mode="RGB")
                # 儲存圖檔
                if _example_label_ == "0":
                    _image_.save(os.path.join(FLAGS.save_iamge_path + _example_label_,  "man_" + str(i) + ".jpg"))
                else:
                    _image_.save(os.path.join(FLAGS.save_iamge_path + _example_label_, "woman_" + str(i) + ".jpg"))
                i = i + 1
        except tf.errors.OutOfRangeError:
            print("Done Test -- epoch limit reaches")
        finally:
            coord.request_stop()
            coord.join(threads=threads)
           

以上就是完整的将圖檔資料轉化為tfrecord檔案, 然後從tfrecord檔案中讀取資料的過程, 建議大家在寫的時候,一定要點進去看看源碼API的參數, 知道要傳參數的類型, 我們才能一步步往下構造整個過程, 否則死記硬背的代碼, 沒有任何用處。

繼續閱讀