天天看點

圖像預處理tf.image.resize_images遇到的坑

image = tf.image.resize_images(image, height, width, method = np.random.randint(4))

報錯:ValueError: 'size' must be a 1-D Tensor of 2 elements

改為:image = tf.image.resize_images(image, new_shape, method = np.random.randint(4)) 後可以正常運作。其中 new_shape 必須是含有兩個元素的清單或一維張量,例如 [height, width],而不是把 height 與 width 分開當作兩個參數傳入。

以下為完整的圖像預處理代碼:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def distort_color(image, color_ordering=0):
    """Apply random color jitter to ``image`` and clip the result to [0.0, 1.0].

    Four random adjustments (brightness, saturation, hue, contrast) are
    applied one after another; ``color_ordering`` selects the order:

    * 0 -- brightness, saturation, hue, contrast
    * 1 -- saturation, brightness, contrast, hue
    * 2 -- hue, saturation, brightness, contrast

    Any other value skips the adjustments, so only the final clipping is
    applied.

    Args:
        image: image tensor to distort.
            NOTE(review): presumably float-valued in [0, 1], given the
            final clip range -- confirm against callers.
        color_ordering: which of the orderings above to use.

    Returns:
        The distorted image with values clipped to [0.0, 1.0].
    """
    def jitter_brightness(img):
        return tf.image.random_brightness(img, max_delta=32. / 255.)

    def jitter_saturation(img):
        return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def jitter_hue(img):
        return tf.image.random_hue(img, max_delta=0.2)

    def jitter_contrast(img):
        return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    # Choose the sequence of adjustments for the requested ordering.
    if color_ordering == 0:
        steps = (jitter_brightness, jitter_saturation, jitter_hue, jitter_contrast)
    elif color_ordering == 1:
        steps = (jitter_saturation, jitter_brightness, jitter_contrast, jitter_hue)
    elif color_ordering == 2:
        steps = (jitter_hue, jitter_saturation, jitter_brightness, jitter_contrast)
    else:
        # Unknown ordering: no color adjustment, only the clip below.
        steps = ()

    for apply_step in steps:
        image = apply_step(image)

    return tf.clip_by_value(image, 0.0, 1.0)

def preprocess_for_train(image, new_shape):
    """Build a randomly augmented, resized version of ``image`` for training.

    The image is converted to ``tf.float32``, resized to ``new_shape`` with a
    randomly chosen resize method, randomly flipped left/right, and finally
    passed through ``distort_color`` with a randomly chosen color ordering.

    The resize method and the color ordering are drawn with NumPy in Python
    at the moment this function is called, so each call fixes one method and
    one ordering for the ops it creates.

    Args:
        image: image tensor, e.g. the output of ``tf.image.decode_jpeg``.
        new_shape: target size passed as the ``size`` argument of
            ``tf.image.resize_images``; it must hold two elements
            (for example ``[32, 32]``), not separate height/width arguments.

    Returns:
        The augmented image tensor produced by ``distort_color``.
    """
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    # NOTE: random bounding-box cropping is intentionally not performed here;
    # the whole image is resized instead.
    distorted_image = tf.image.resize_images(
        image, new_shape, method=np.random.randint(4))

    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # distort_color defines three orderings (0, 1 and 2). The upper bound of
    # np.random.randint is exclusive, so 3 is required to reach ordering 2.
    distorted_image = distort_color(distorted_image, np.random.randint(3))

    return distorted_image

# Read the encoded JPEG as raw bytes. Binary mode is required because the
# default text mode tries to decode the file as text, which fails for JPEG
# data on Python 3. The ``with`` block closes the file handle afterwards.
with tf.gfile.FastGFile("/home/diana/test/dog/g1.jpg", "rb") as image_file:
    image_raw_data = image_file.read()

with tf.Session() as sess:
    # Decode once and print the original pixel values for reference.
    img_data = tf.image.decode_jpeg(image_raw_data)
    print(img_data.eval())

    # Show six independently augmented 32x32 variants of the same image.
    # Each call to preprocess_for_train adds new ops to the graph.
    for _ in range(6):
        result = preprocess_for_train(img_data, [32, 32])
        plt.imshow(result.eval())
        plt.show()
           

運作結果:

圖像預處理tf.image.resize_images遇到的坑
圖像預處理tf.image.resize_images遇到的坑
圖像預處理tf.image.resize_images遇到的坑
圖像預處理tf.image.resize_images遇到的坑
圖像預處理tf.image.resize_images遇到的坑
圖像預處理tf.image.resize_images遇到的坑
[[[166 135  89]
  [166 135  89]
  [166 135  91]
  ..., 
  [ 58  74  64]
  [ 57  74  64]
  [ 57  74  64]]

 [[166 135  89]
  [166 135  89]
  [166 135  91]
  ..., 
  [ 58  74  64]
  [ 57  74  64]
  [ 57  74  64]]

 [[163 134  90]
  [163 134  90]
  [163 134  90]
  ..., 
  [ 58  74  64]
  [ 57  74  64]
  [ 57  74  64]]

 ..., 
 [[215 195 171]
  [213 195 171]
  [213 195 171]
  ..., 
  [129 162 177]
  [125 161 177]
  [122 159 175]]

 [[216 198 176]
  [215 199 176]
  [215 199 176]
  ..., 
  [133 167 179]
  [129 165 179]
  [127 163 177]]

 [[217 200 180]
  [217 200 180]
  [217 202 181]
  ..., 
  [135 169 181]
  [131 167 181]
  [130 166 180]]]
      

繼續閱讀