
Image to tfrecord and tfrecord to Image

I've recently been learning TensorFlow for image processing, saving the processed images as tfrecord files. While converting, I found that going from RGB to grayscale in TensorFlow and then writing the tfrecord file took a very long time. After looking into it carefully, I found that preprocessing and saving images with PIL is far faster, cutting the time by a factor of a hundred or more.
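To see the gap for yourself, here is a minimal timing sketch of my own (not part of the tutorial code; the glob path is the same sample path used below, adjust it to your data). It mirrors the two branches of the makeTFRecord function in the code below. Note that the TF branch keeps adding new ops to the graph on every loop iteration, so each sess.run call gets progressively slower, while PIL does the resize eagerly in native code:

import glob
import time
import tensorflow as tf
from PIL import Image

sample_glob = r"G:\AI\Images\n02085620-Chihuahua\*.jpg"  # placeholder, point at your images
filenames = glob.glob(sample_glob)[:50]

# TF path: build decode/grayscale/resize ops and evaluate them once per image
with tf.Session() as sess:
    start = time.time()
    for fn in filenames:
        image = tf.image.decode_jpeg(tf.read_file(fn))
        resized = tf.image.resize_images(tf.image.rgb_to_grayscale(image), [256, 256])
        _ = sess.run(tf.cast(resized, tf.uint8)).tobytes()
    print("TF per-image sess.run: %.2fs" % (time.time() - start))

# PIL path: open, resize, and serialize eagerly, no session involved
start = time.time()
for fn in filenames:
    _ = Image.open(fn).resize((256, 256)).tobytes()
print("PIL: %.2fs" % (time.time() - start))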

Below is my code for converting Image to tfrecord and tfrecord back to Image. If it helps you, please star my GitHub:

https://github.com/Alex-AI-Du/Tensorflow-Tutorial/blob/master/image_tfrecord/Image2tfrecord2image.py
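(Note: the script uses TensorFlow 1.x APIs such as tf.python_io.TFRecordWriter, tf.train.string_input_producer, and queue runners, which were removed in TensorFlow 2.x, so run it under a 1.x install.)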

import os
import tensorflow as tf
from PIL import Image
import glob
from itertools import groupby
from collections import defaultdict
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress noisy TF warnings
IMAGE_WIDTH = 256
IMAGE_HEIGHT = 256
IMAGE_CHANNEL = 3
def split_image_dataset(path, training_dataset, testing_dataset=False):
    image_filenames = glob.glob(path)
    # Pair each filename with its breed, taken from the parent directory name (Windows path separators)
    image_filename_with_breed = list(map(lambda filename: (filename.split("\\")[-2], filename), image_filenames))
    for category, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
        # Roughly every fifth image goes to the testing set, the rest to training
        for i, breed_image in enumerate(breed_images):
            if i % 5 == 0 and testing_dataset is not False:
                testing_dataset[category].append(breed_image[1])
            else:
                training_dataset[category].append(breed_image[1])
        if testing_dataset is not False:
            # Sanity check: the testing split should hold at least ~18% of each category
            category_training_count = len(training_dataset[category])
            category_testing_count = len(testing_dataset[category])
            assert round(float(category_testing_count) / float(category_training_count + category_testing_count), 2) > 0.18, "Not enough testing images."
    if testing_dataset is not False:
        print("training_dataset testing_dataset END ------------------------------------------------------")
    else:
        print("training_dataset END ------------------------------------------------------")

# Build the TFRecord files
def makeTFRecord(dataset, record_location, sess, tfread=False, rows=IMAGE_WIDTH, cols=IMAGE_HEIGHT):
    # record_location is a filename prefix, so create its parent directory
    record_dir = os.path.dirname(record_location)
    if not os.path.exists(record_dir):
        print("Directory %s does not exist, creating it..." % record_dir)
        os.makedirs(record_dir)
    writer = None
    current_index = 0
    for category, images_filenames in dataset.items():
        for image_filename in images_filenames:
            # Start a new shard every 100 images
            if current_index % 100 == 0:
                if writer:
                    writer.close()
                record_filename = "{record_location}-{current_index}.tfrecords".format(
                    record_location=record_location,
                    current_index=current_index)
                writer = tf.python_io.TFRecordWriter(record_filename)
                print(record_filename + "------------------------------------------------------")
            current_index += 1
            if tfread:
                # Slow path: build TF ops and run them per image; this adds new ops to the
                # graph on every iteration, so each sess.run gets progressively slower
                image_file = tf.read_file(image_filename)
                try:
                    image = tf.image.decode_jpeg(image_file)
                except:
                    print(sys._getframe().f_lineno, image_filename)
                    continue
                # Note: the grayscale output has 1 channel, while read_and_decode below
                # reshapes with IMAGE_CHANNEL (3) channels
                grayscale_image = tf.image.rgb_to_grayscale(image)
                resized_image = tf.image.resize_images(grayscale_image, [rows, cols])
                image_bytes = sess.run(tf.cast(resized_image, tf.uint8)).tobytes()
            else:
                # Fast path: resize with PIL and serialize the raw pixels
                image = Image.open(image_filename)
                image = image.resize((rows, cols))
                image_bytes = image.tobytes()  # serialize the image to raw bytes
            image_label = category.encode("utf-8")
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
                }))
            writer.write(example.SerializeToString())
    if writer:
        writer.close()
    print("write_records_file testing_dataset training_dataset END------------------------------------------------------")

# Read the raw bytes back into images; rows/cols are swapped relative to makeTFRecord
# (rows=cols, cols=rows) because PIL's resize takes (width, height)
def read_and_decode(filequeuelist, rows=IMAGE_HEIGHT, cols=IMAGE_WIDTH):
    fileName_Queue = tf.train.string_input_producer(tf.train.match_filenames_once(filequeuelist))  # build a filename queue

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(fileName_Queue)  # returns the filename and a serialized example
    features = tf.parse_single_example(serialized_example,
                                     features={
                                        'label': tf.FixedLenFeature([], tf.string),
                                        'image': tf.FixedLenFeature([], tf.string)
                                     })  # pull the image data and the label out of the example
    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, [rows, cols, IMAGE_CHANNEL])  # reshape into a rows*cols image with IMAGE_CHANNEL channels
    #img = tf.cast(img, tf.float32) * (1.0 / 255) - 0.5  # optionally normalize to [-0.5, 0.5]
    label = tf.cast(features['label'], tf.string)
    return img, label

def display_image(filequeuelist, save_dir, sess):
    # Create the output directory
    if not os.path.exists(save_dir):
        print("Directory %s does not exist, creating it..." % save_dir)
        os.makedirs(save_dir)
    # Build a filename queue from the tfrecord file pattern
    fileName_Queue = tf.train.string_input_producer(tf.train.match_filenames_once(filequeuelist))

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(fileName_Queue)  # returns the filename and a serialized example
    features = tf.parse_single_example(serialized_example,
                        features={
                           'label': tf.FixedLenFeature([], tf.string),  # labels were written as bytes, not int64
                           'image': tf.FixedLenFeature([], tf.string)
                        })  # pull out a feature dict holding image and label
    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])  # reshape into a 3-channel image
    label = tf.cast(features['label'], tf.string)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(5):
        example, l = sess.run([img, label])  # pull image and label out of the session
        # Careful not to shadow the graph tensor names (img/label) here
        image = Image.fromarray(example, 'RGB')  # Image is the PIL class imported above
        label_dir = r"%s\%s" % (save_dir, l.decode())
        if not os.path.exists(label_dir):
            os.makedirs(label_dir)
        image.save(r"%s\%s.jpg" % (label_dir, i))  # write the image to disk
    coord.request_stop()
    coord.join(threads)

if __name__ == "__main__":
    sess = tf.InteractiveSession()
    training_dataset = defaultdict(list)
    testing_dataset = defaultdict(list)
    # Source path of the images, grouped into one directory per category
    cwd = r"G:\AI\Images\n02085620-Chihuahua\*.jpg"
    # Where the images decoded back out of the tfrecord files are saved
    save_dir = r"G:\AI\test"
    split_image_dataset(cwd, training_dataset, testing_dataset)
    # Convert the images to tfrecord files (images are resized to 256*256*3)
    makeTFRecord(training_dataset, "F:/TS/TS_p_c/test/training-images/training-image", sess)
    makeTFRecord(testing_dataset, "F:/TS/TS_p_c/test/testing-images/testing-image", sess)
    # Convert the tfrecord file back to images (256*256*3; change the size constants above if needed)
    [img, label] = read_and_decode("F:/TS/TS_p_c/test/training-images/training-image-0.tfrecords")
    #img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=18, capacity=2000, min_after_dequeue=100,num_threads=2)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # print("----------",sess.run(img.shape),sess.run(label.shape))
    for i in range(1000):
        example, l = sess.run([img, label])  # pull image and label out of the session
        l = l.decode()
        # Careful not to shadow the graph tensor names (img/label) here
        image = Image.fromarray(example, 'RGB')  # Image is the PIL class imported above
        path = r"%s\%s\%s.jpg" % (save_dir, l, i)
        if not os.path.exists(r"%s\%s" % (save_dir, l)):
            print(r"Directory %s\%s does not exist, creating it..." % (save_dir, l))
            os.makedirs(r"%s\%s" % (save_dir, l))
        image.save(path)  # write the image to disk
    coord.request_stop()
    coord.join(threads)
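
For training you usually want batches rather than single examples. Here is a minimal sketch of my own along the lines of the commented-out shuffle_batch line above (the batch parameters are illustrative, not tuned):

img, label = read_and_decode("F:/TS/TS_p_c/test/training-images/training-image-0.tfrecords")
img_batch, label_batch = tf.train.shuffle_batch(
    [img, label],
    batch_size=18,          # examples per training step
    capacity=2000,          # maximum number of examples the queue holds
    min_after_dequeue=100,  # minimum left after a dequeue; larger means better shuffling
    num_threads=2)

init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    images, labels = sess.run([img_batch, label_batch])  # images.shape == (18, 256, 256, 3)
    coord.request_stop()
    coord.join(threads)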