def data_loader(FLAGS):
    """Build the TF1 queue-based input pipeline for paired LR/HR PNG images.

    LR and HR images are matched by identical basenames in the two input
    directories (only the LR directory is listed; HR paths are derived from
    the same filenames — TODO confirm HR directory contains every basename).

    Args:
        FLAGS: parsed flags; this function reads input_dir_LR, input_dir_HR,
            mode, task ('SRGAN' | 'SRResnet' | 'denoise'), batch_size,
            crop_size, random_crop, flip, name_queue_capacity,
            image_queue_capacity and queue_thread.

    Returns:
        Data namedtuple with fields paths_LR, paths_HR, inputs, targets,
        image_count and steps_per_epoch.

    Raises:
        ValueError: if an input directory is 'None' or does not exist.
        Exception: if the LR directory contains no .png files.
    """
    with tf.device('/cpu:0'):
        # Container for the batches handed back to the caller.
        Data = collections.namedtuple(
            'Data', 'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch')

        # Validate the input directories.
        if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
            raise ValueError('Input directory is not provided')
        if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
            raise ValueError('Input directory not found')

        image_list_LR = [f for f in os.listdir(FLAGS.input_dir_LR) if f.endswith('.png')]
        if not image_list_LR:
            raise Exception('No png files in the input directory')

        # Sort so the LR/HR pairs line up deterministically by filename.
        image_list_LR_temp = sorted(image_list_LR)
        image_list_LR = [os.path.join(FLAGS.input_dir_LR, name) for name in image_list_LR_temp]
        image_list_HR = [os.path.join(FLAGS.input_dir_HR, name) for name in image_list_LR_temp]
        image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
        image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)

        with tf.variable_scope('load_image'):
            # Paired filename queue; shuffle=False keeps the LR and HR
            # entries of each pair in lockstep.
            output = tf.train.slice_input_producer(
                [image_list_LR_tensor, image_list_HR_tensor],
                shuffle=False, capacity=FLAGS.name_queue_capacity)

            # Read and decode one image pair. (The original created an
            # unused tf.WholeFileReader here; tf.read_file is what is used.)
            image_LR = tf.read_file(output[0])
            image_HR = tf.read_file(output[1])
            input_image_LR = tf.image.decode_png(image_LR, channels=3)
            input_image_HR = tf.image.decode_png(image_HR, channels=3)
            input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
            input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)

            # Runtime sanity check on the channel count (decode_png with
            # channels=3 should already guarantee this).
            assertion = tf.assert_equal(
                tf.shape(input_image_LR)[2], 3, message="image does not have 3 channels")
            with tf.control_dependencies([assertion]):
                input_image_LR = tf.identity(input_image_LR)
                input_image_HR = tf.identity(input_image_HR)

            # Normalize the low-resolution image to [0, 1] and the
            # high-resolution image to [-1, 1] via the project helpers.
            a_image = preprocessLR(input_image_LR)
            b_image = preprocess(input_image_HR)

        inputs, targets = [a_image, b_image]

        # The data augmentation part.
        with tf.name_scope('data_preprocessing'):
            with tf.name_scope('random_crop'):
                # Random crop is applied only during training.
                # FIX: test the flags by truthiness instead of the fragile
                # identity comparison "is True".
                if FLAGS.random_crop and FLAGS.mode == 'train':
                    print('[Config] Use random crop')
                    # Pick a random top-left corner inside the LR image; for
                    # SR tasks the HR crop is the corresponding 4x region.
                    input_size = tf.shape(inputs)
                    offset_w = tf.cast(
                        tf.floor(tf.random_uniform(
                            [], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
                        dtype=tf.int32)
                    offset_h = tf.cast(
                        tf.floor(tf.random_uniform(
                            [], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
                        dtype=tf.int32)
                    if FLAGS.task in ('SRGAN', 'SRResnet'):
                        inputs = tf.image.crop_to_bounding_box(
                            inputs, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                        targets = tf.image.crop_to_bounding_box(
                            targets, offset_h * 4, offset_w * 4,
                            FLAGS.crop_size * 4, FLAGS.crop_size * 4)
                    elif FLAGS.task == 'denoise':
                        inputs = tf.image.crop_to_bounding_box(
                            inputs, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                        targets = tf.image.crop_to_bounding_box(
                            targets, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                # Do not perform crop.
                else:
                    inputs = tf.identity(inputs)
                    targets = tf.identity(targets)

            with tf.variable_scope('random_flip'):
                # Random flip is applied only during training; a single
                # shared decision keeps the LR/HR pair aligned.
                if FLAGS.flip and FLAGS.mode == 'train':
                    print('[Config] Use random flip')
                    decision = tf.random_uniform([], 0, 1, dtype=tf.float32)
                    input_images = random_flip(inputs, decision)
                    target_images = random_flip(targets, decision)
                else:
                    input_images = tf.identity(inputs)
                    target_images = tf.identity(targets)

            # Pin static per-image shapes so downstream graph construction
            # can rely on them; SR targets are 4x the crop size.
            if FLAGS.task in ('SRGAN', 'SRResnet'):
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
            elif FLAGS.task == 'denoise':
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])

        # Batching: shuffled batches for training, in-order (possibly
        # smaller final) batches otherwise.
        if FLAGS.mode == 'train':
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch(
                [output[0], output[1], input_images, target_images],
                batch_size=FLAGS.batch_size,
                capacity=FLAGS.image_queue_capacity + 4 * FLAGS.batch_size,
                min_after_dequeue=FLAGS.image_queue_capacity,
                num_threads=FLAGS.queue_thread)
        else:
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch(
                [output[0], output[1], input_images, target_images],
                batch_size=FLAGS.batch_size, num_threads=FLAGS.queue_thread,
                allow_smaller_final_batch=True)

        steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size))

        # Pin the static batch shapes as well.
        if FLAGS.task in ('SRGAN', 'SRResnet'):
            inputs_batch.set_shape(
                [FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape(
                [FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
        elif FLAGS.task == 'denoise':
            inputs_batch.set_shape(
                [FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape(
                [FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])

    return Data(
        paths_LR=paths_LR_batch,
        paths_HR=paths_HR_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        image_count=len(image_list_LR),
        steps_per_epoch=steps_per_epoch,
    )