天天看點

如何部署自己的SSD檢測模型到Android TFLite上

TensorFlow Object Detection API 上提供了使用SSD部署到TFLite運作上去的方法, 可是這套API封裝太死闆, 如果你要自己實作了一套SSD的訓練算法,應該怎麼才能部署到TFLite上呢?

首先,抛開後處理的部分,你的SSD模型(無論是VGG-SSD和Mobilenet-SSD), 你最終的模型的輸出是對class_predictions和bbox_predictions; 并且是encoded的

Encoding的方式:

class_predictions: M個Feature Layer, Feature Layer的大小(寬高)視網絡結構而定; 每個Feature Layer有Num_Anchor_Depth_of_this_layer x Num_classes個channels

box_predictions:   M個Feature Layer; 每個Feature Layer有Num_Anchor_Depth_of_this_layer x 4個channes 這4個channel分别代表dy,dx,h,w, 即bbox中心距離anchor中心坐标的偏移量和寬高

注:通常,為了平衡loss之間的大小, 不會直接編碼dy,dx,w,h的原始值,而是dy/anchor_h*scale0, dx/anchor_w*scale0, log(h/anchor_h)*scale1, log(w/anchor_w)*scale1, 也就是偏移量的絕對值除anchor寬高得到相對值,然後再乘上一個scale, 經驗值 scale0取5,scale1取10; 對于h,w是對得到相對值後先取log再乘以scale, h/anchor_h的範圍在1附近, 取log後可以轉換到0附近;是以在解碼的時候需要做對應相反的變換;

在後面TFLite_Detection_PostProcess的Op實作裡就有這麼一段:

如何部署自己的SSD檢測模型到Android TFLite上

然後我們需要的是做的是decode出來對每個class的confidence和location的預測值

後處理

在Object Detection API的 export_tflite_ssd_graph_lib.py檔案中,你可以看到,它差別與直接freeze pb的操作就在于最後替換了後處理的部分;

Plain Text

Bash

Basic

C

C++

C#

CSS

Diff

Git

go

GraphQL

HTML

HTTP

Java

JavaScript

JSON

JSX

Kotlin

Less

Makefile

Markdown

MATLAB

Nginx

Objective-C

Pascal

Perl

PHP

PowerShell

Ruby

Protobuf

Python

R

Scala

Shell

SQL

Swift

TypeScript

VB.net

XML

YAML

KaTeX

Mermaid

PlantUML

Flow

Graphviz

frozen_graph_def = exporter.freeze_graph_with_def_protos(

input_graph_def=tf.get_default_graph().as_graph_def(),

input_saver_def=input_saver_def,

input_checkpoint=checkpoint_to_use,

output_node_names=','.join([

'raw_outputs/box_encodings', 'raw_outputs/class_predictions',

'anchors'

]),

restore_op_name='save/restore_all',

filename_tensor_name='save/Const:0',

clear_devices=True,

output_graph='',

initializer_nodes='')

# Add new operation to do post processing in a custom op (TF Lite only)

if add_postprocessing_op:

transformed_graph_def = append_postprocessing_op(

frozen_graph_def, max_detections, max_classes_per_detection,

nms_score_threshold, nms_iou_threshold, num_classes, scale_values)

else:

# Return frozen without adding post-processing custom op

transformed_graph_def = frozen_graph_def

後處理的部分,其實看代碼也很簡單,就是增加了一個叫TFLite_Detection_PostProcess的node,用于解碼和非極大抑制. 這個node的輸入就是上面提到的box_predictions和class_predictions, 還有anchors的編碼; 用這個node的目的隻TFLite并不支援tf.contrib.image.non_max_surpression操作

Reshape過程:

這裡需要明确,TFLite_Detection_PostProcess 這個op對raw_outputs/box_encodings, raw_outputs/class_predictions, anchors的Shape是有一個定制要求的

raw_outputs/box_encodings.shape=[1, num_anchors,4]

raw_outputs/class_predictions.shape=[1, num_anchors,num_classes+1]

anchors.shape=[1,num_anchors,4]

這裡需要注意:1, 這三個都必須是3維的Tensor; 2.raw_outputs/class_predictions.shape的最後一個次元是包含background的classes, 也就是是num_classes+1; TFLite_Detection_PostProcess還有一個參數num_classes, 這個參數值是不包含background的, 是以也就導緻TFLite_Detection_PostProcess的輸出的class index是從0計數的;

with tf.variable_scope('raw_outputs'):

cls_pred = [tf.reshape(pred, [-1, num_classes]) for pred in cls_pred]

location_pred = [tf.reshape(pred, [-1, 4]) for pred in location_pred]

cls_pred = tf.concat(cls_pred, axis=0)

location_pred = tf.expand_dims(tf.concat(location_pred, axis=0),0, name='box_encodings')

cls_pred=tf.nn.softmax(cls_pred)

tf.identity(tf.expand_dims(cls_pred,0), name='class_predictions')

這段代碼就是用來reshape成要求的輸入的, 需要注意的是對class_prediction需要做依次softmax或者sigmoid, 具體選擇哪種取決于你是否允許一個anchor對應多個類;

對于anchors, 這其實是一constant的值:

num_anchors = anchor_cy.get_shape().as_list()

with tf.Session() as sess:

y_out, x_out, h_out, w_out = sess.run([anchor_cy, anchor_cx, anchor_h, anchor_w])

encoded_anchors = tf.constant(

np.transpose(np.stack((y_out, x_out, h_out, w_out))),

dtype=tf.float32,

shape=[num_anchors[0], 4])

注意: 之前我使用tf.stack合成這個值的時候發現,TFLite隻支援axis=0的時候的tf.stack, 否則就會轉換是吧

導出pb

添加完後處理,既可以導出一個帶有後處理功能的pb檔案了; 如果你不添加後處理,把它放在CPU上後續去做,其實也可以省去不少麻煩;

binary_graph = os.path.join(output_dir, 'tflite_graph.pb')

with tf.gfile.GFile(binary_graph, 'wb') as f:

f.write(transformed_graph_def.SerializeToString())

txt_graph = os.path.join(output_dir, 'tflite_graph.pbtxt')

with tf.gfile.GFile(txt_graph, 'w') as f:

f.write(str(transformed_graph_def))

注意: 導出的pb如果包含後處理, 是沒辦法用正常的TF執行的,必須轉成tflite執行

導出tflite

bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \

--input_file=$OUTPUT_DIR/tflite_graph.pb \

--output_file=$OUTPUT_DIR/detect.tflite \

--input_shapes=1,300,300,3 \

--input_arrays=normalized_input_image_tensor \

--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \

--inference_type=QUANTIZED_UINT8 \

--mean_values=128 \

--std_values=128 \

--change_concat_input_ranges=false \

--allow_custom_ops

or

bazel run -c opt tensorflow/lite/toco:toco -- \

--inference_type=FLOAT \

導出的過程中,可能遇到Converting unsupported operation: TFLite_Detection_PostProcess 這個提示, 正常如果是TF在1.10以上就忽略這個提示好了

然後你可以先用python的程式加載這個tflite去測試一下

注意: 這時候會發現一個問題, TFLite_Detection_PostProcess的NMS操作是忽略類标簽的,如果你設定max_classes_per_detection=1; 但是如果你設定成>1的值, 會發現它吧background的标簽也算進來了, 導緻出來很多誤檢測的bbox;

部署Android

然後,你可以嘗試部署到Android上, 在不使用NNAPI的時候正常,但是如果是NNAPI就需要自己實作相關操作了,否則會crash掉