
TensorFlow: tflite_convert and Interpreter

1. Convert ".pb" to ".tflite"

Method 1: Convert with the Python API

import tensorflow as tf

graph_def_file = "retrained_graph_mobilenet_v2_1.4_224.pb"
input_arrays = ["input"]            # name of the graph's input tensor
# input_shapes can optionally be passed as well; see the sketch below
output_arrays = ["final_result"]    # name of the graph's output tensor

# In newer TF 1.x releases the converter is exposed as tf.lite.TFLiteConverter.
converter = tf.contrib.lite.TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
with open("retrained_graph_mobilenet_v2_1.4_224_python.tflite", "wb") as f:
    f.write(tflite_model)

Method 2: Convert from the Command Line (tflite_convert)

IMAGE_SIZE=224
tflite_convert \
  --graph_def_file=tf_files/retrained_graph_mobilenet_v2_1.4_224.pb \
  --output_file=tf_files/retrained_graph_mobilenet_v2_1.4_224_batch3.lite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --input_shape=7,${IMAGE_SIZE},${IMAGE_SIZE},3 \
  --input_array=input \
  --output_array=final_result \
  --inference_type=FLOAT \
  --input_data_type=FLOAT

# or invoke the converter as a module (same flags as above):
#   python3 -m tensorflow.contrib.lite.python.tflite_convert \

2. Run the ".tflite" Model with the Interpreter

# Image data pipeline: decode an image file into a normalized float tensor.

import tensorflow as tf

def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
                                input_mean=0, input_std=255):
  """Decode an image file, resize it and normalize it to a (1, H, W, 3) float array."""
  input_name = "file_reader"
  output_name = "normalized"
  file_reader = tf.read_file(file_name, input_name)

  if file_name.endswith(".png"):
    image_reader = tf.image.decode_png(file_reader, channels=3,
                                       name='png_reader')
  elif file_name.endswith(".gif"):
    image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
                                                  name='gif_reader'))
  elif file_name.endswith(".bmp"):
    image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
  else:
    image_reader = tf.image.decode_jpeg(file_reader, channels=3,
                                        name='jpeg_reader')

  float_caster = tf.cast(image_reader, tf.float32)
  dims_expander = tf.expand_dims(float_caster, 0)    # add a batch dimension
  resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
  normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])

  with tf.Session() as sess:                         # close the session when done
    result = sess.run(normalized)
  return result
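
Each call returns a single image with a leading batch dimension of 1, which is why the main script below squeezes it before stacking. A quick sanity check, assuming a local file "a.jpg":

sample = read_tensor_from_image_file("a.jpg",
                                     input_height=224, input_width=224,
                                     input_mean=128, input_std=128)
print(sample.shape)   # (1, 224, 224, 3)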



# args: MobileNet expects 224x224 inputs normalized to roughly [-1, 1],
# i.e. (pixel - 128) / 128, hence mean and std of 128.
input_height = 224
input_width = 224
input_mean = 128
input_std = 128





#---------------------------------
# main script
#---------------------------------

import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
# (Newer TF 1.x releases expose this as tf.lite.Interpreter.)
interpreter = tf.contrib.lite.Interpreter(model_path="tf_lite_for_image/tf_files/retrained_graph_mobilenet_v2_1.4_224_batch3.lite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


# The number of files must match the batch size baked into the converted model
# (see the Note below); two files are listed here only as an example.
file_name_list = ["a.jpg", "b.jpg"]

input_batch = []
for file_name in file_name_list:
    temp = read_tensor_from_image_file(file_name,
                                       input_height=input_height,
                                       input_width=input_width,
                                       input_mean=input_mean,
                                       input_std=input_std)
    input_batch.append(np.squeeze(temp))                  # drop the per-image batch dim
input_batch = np.array(input_batch, dtype=np.float32)     # stack into (batch, H, W, 3)


# Feed the whole batch into the model's input tensor.
interpreter.set_tensor(input_details[0]['index'], input_batch)


# Run inference and read back the class scores for every image in the batch.
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
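
To turn the raw scores into predictions, take the arg-max over the class axis. A short sketch; the labels file name below is an assumption (the retraining tutorial writes retrained_labels.txt next to the .pb):

predictions = np.argmax(output_data, axis=1)    # one class index per image
# labels = [line.strip() for line in open("tf_files/retrained_labels.txt")]  # assumed path
# print([labels[i] for i in predictions])
print(predictions)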

Note:

When converting, fix the input shape of the .lite model, for example to a batch of 7 images; the whole batch is then pushed through the network in a single invoke() call:

       --input_shape=7,${IMAGE_SIZE},${IMAGE_SIZE},3

file_name_list must then contain exactly 7 samples, and output_data has shape (7, classes_num).
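
A minimal sanity check before calling set_tensor: read the batch size the .lite file actually expects from input_details and make sure file_name_list matches it (a sketch using the variables from the main script above):

expected_batch = int(input_details[0]['shape'][0])     # batch size baked into the .lite model
assert len(file_name_list) == expected_batch, \
    "file_name_list must contain exactly %d images" % expected_batch
assert tuple(input_batch.shape) == tuple(input_details[0]['shape'])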
