天天看點

将分類網絡應用在android中 part2,用自己的訓練結果應用android目錄準備工作編譯應用代碼解讀

目錄

準備工作

編譯應用

代碼解讀

準備工作

1.儲存checkpoint

可以參考之前的一篇利用tf slim進行分類網絡訓練的部落格,部落格位址,如果按照裡面的操作步驟進行訓練網絡,我們會得到儲存下來的checkpoint檔案。

model.ckpt-5000.data-00000-of-00001 --> 儲存了目前參數值
model.ckpt-5000.index --> 儲存了目前參數名
model.ckpt-5000.meta --> 儲存了目前graph結構圖
           

這樣訓練的腳本就直接幫我們完成了checkpoint的儲存。

但是如果是自己實作的網絡結構和網絡訓練,那我們需要使用下面的代碼來儲存checkpoint,然後我們同樣也會得到這三類檔案。

saver = tf.train.Saver()
saver.save(sess, './data/train_logs_1/model.chkp')
           

2.根據meta生成freezed pb

下面就是要根據生成的checkpoint檔案來生成freezed protobuf檔案。什麼是freezed呢?其實就是将參數的值和graph結合起來儲存成pb檔案,這樣後續使用的時候就隻需要直接輸入input進行計算就好了,也不用還原網絡結構。當然pb檔案裡面其實都是二進制的資訊,也無法還原網絡結構的,我們在運算的時候隻能按照裡面記錄的運算方式進行計算。

import tensorflow as tf

meta_path = './data/train_logs_1/model.ckpt-5000.meta' # Your .meta file
output_node_names = ['MobilenetV2/Predictions/Reshape_1']    # Output nodes

with tf.Session() as sess:

    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph
    with open('./data/train_logs_1/freeze_graph_5000.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())
           

上面的代碼可以儲存一個freezed_graph.pb檔案。在使用convert_variables_to_constants的時候需要一個output_node_names,如果不知道output的具體名字可以先用如下的代碼檢視

import tensorflow as tf

meta_path = './data/train_logs_1/model.ckpt-5000.meta' # Your .meta file

with tf.Session() as sess:

    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))

    graph = tf.get_default_graph()

    with open('./data/train_logs_1/operations_5000.txt', 'wb') as f:
        for op in graph.get_operations():
            f.writelines(str(op.name) + ',' + str(op.values()) + '\n')
           

這裡會将graph裡面包含的所有操作列印出來,因為内容比較多,是以存入檔案友善檢視,可以從operations_5000.txt中看到在loss之前最後一個輸出就是MobilenetV2/Predictions/Reshape_1,是以我們在convert_variables_to_constants中填入的output_node_names為['MobilenetV2/Predictions/Reshape_1'],其實可以填入好多個output,但是我們的分類網絡隻需要一個。

MobilenetV2/Predictions/Reshape_1,(<tf.Tensor 'MobilenetV2/Predictions/Reshape_1:0' shape=(32, 243) dtype=float32>,)
softmax_cross_entropy_loss/Rank,(<tf.Tensor 'softmax_cross_entropy_loss/Rank:0' shape=() dtype=int32>,)
           

3.測試freezed pb

前面兩個步驟後其實freezed pb檔案就已經儲存成功了,但是還需要測試一下我們儲存的pb檔案是否可靠,是否可以通過load這個pb檔案就進行預測。

首先我們需要load pb檔案

def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph
           

接着取出input和output的tensor

input_operation = graph.get_operation_by_name(input_name)
  output_operation = graph.get_operation_by_name(output_name)
           

最後sess run就可以得到結果了

with tf.Session(graph=graph) as sess:
    results = sess.run(output_operation.outputs[0], {
        input_operation.outputs[0]: t
    })
  results = np.squeeze(results)
  print(results.shape)

  top_k = results.argsort()[-5:][::-1]
  labels = load_labels(label_file)
           

雖然測試pb檔案代碼很簡單,但是我們可能會遇到兩個坑。

第一個可能運作後會報錯。因為我們列印出來的results的shape是[32, 243],243是我們分類的類别數,但是為何有32個243的數組呢?

(1, 224, 224, 3)
2018-09-04 18:19:58.424800: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:03:00.0, compute capability: 6.1)
(32, 243)
Traceback (most recent call last):
  File "test_freeze_meta.py", line 80, in <module>
    print(labels[i], results[i])
TypeError: only integer scalar arrays can be converted to a scalar index
           

從我們前面列印出來的output的size可以看出MobilenetV2/Predictions/Reshape_1的shape是(32, 243),是因為32是我們之前training的batch size,而我們儲存meta的時候将這個size儲存下來了。是以導緻我們predict的時候出來的結果也是(32, 243),但是我們其實隻預測了一張圖檔。

如果遇到了這個問題,可以将training時的batch size改成1,基于前面的checkpoint再進行一次訓練生成新的meta檔案,然後重複上面的步驟進行操作即可。比如我這邊會用下面的指令接着進行training

python train_image_classifier.py \
    --train_dir=./data/train_logs_1 \
    --dataset_dir=./data/mydata \
    --dataset_name=mydata \
    --dataset_split_name=train \
    --model_name=mobilenet_v2 \
    --train_image_size=224 \
    --batch_size=1
           

會生成model.ckpt-5001.data-00000-of-00001,model.ckpt-5001.index和model.ckpt-5001.meta檔案,然後重新生成operations_5001.txt,可以發現裡面shape已經變過來了

MobilenetV2/Predictions/Reshape_1,(<tf.Tensor 'MobilenetV2/Predictions/Reshape_1:0' shape=(1, 243) dtype=float32>,)
softmax_cross_entropy_loss/Rank,(<tf.Tensor 'softmax_cross_entropy_loss/Rank:0' shape=() dtype=int32>,)
           

後面進行預測就不會有報錯。但是其實并沒有結束,因為我們還可能遇到第二個坑。

預測的結果特别不準确,而且多跑幾次會發現每次結果都不一樣。這是因為batch normalization和dropout的随機性導緻的

with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
           

代碼中在搭建網絡的時候很清楚的對這兩種操作區分了是否是training狀态。是以我們現在的做法是搭建網絡的時候傳入這個參數為false,然後進行一次training生成checkpoint 5002。

network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        weight_decay=FLAGS.weight_decay,
        is_training=False)
           

如果對5002 checkpoint進行評估會發現結果是非常不準确的,沒關系,我們隻需要用到他的meta檔案。

然後在用第二步中的操作生成freezed pb檔案

meta_path = './data/train_logs_1/model.ckpt-5002.meta' # Your .meta file

#這裡restore的checkpoint需要是準确率比較高的checkpoint,比如ckpt-5001
#如果讓latest_checkpoint取到的是5001呢,很簡單,修改train_logs_1目錄下的checkpoint檔案
#修改model_checkpoint_path: "model.ckpt-5001",這樣就會自動取5001為checkpoint來恢複資料了
saver.restore(sess, tf.train.latest_checkpoint('./data/train_logs_1'))
           

然後再進行預測,一切都正常了。比如預測file_name = "./backup/mydata/km335_back/km335_back.jpg"檔案,結果是

(243,)
('150:km335_back', 0.99930513)
('151:km335_front', 0.00019936822)
('191:km711_front', 0.00018467742)
('193:km712_front', 0.00015364563)
('220:kmmerge123_back', 4.9729293e-05)
           

并不是每個人都會遇到這兩個問題,如果是自己搭建網絡,自己儲存checkpoint我想是可以避免的,但是tf-slim是用slim.learning.train接口進行訓練和儲存checkpoint,是以儲存形式不太可控。

代碼實作:freeze_meta.py   test_freeze_meta.py

編譯應用

如果按照上一篇博文(連結)進行了實操,那這一步就會非常容易了。

首先将上一步編譯出來的pb檔案,和我們分類的label檔案拷貝放入tensorflow/examples/android/assets

然後修改ClassifierActivity.java中的代碼如下

private static final String INPUT_NAME = "MobilenetV2/input";
  private static final String OUTPUT_NAME = "MobilenetV2/Predictions/Reshape_1";  

  private static final String MODEL_FILE = "file:///android_asset/freeze_graph_5002.pb";
  private static final String LABEL_FILE = "file:///android_asset/labels.txt";
           

接着用bazel進行編譯

bazel build //tensorflow/examples/android:tensorflow_coin
           

生成的apk放在了bazel-bin/tensorflow/examples/android目錄下,安裝啟動即可。

但是我們可能會遇到另外一個坑,安裝apk後啟動會crash,從adb log看到的錯誤是

09-05 14:07:58.741 16317 16441 E AndroidRuntime: java.lang.IllegalArgumentException: Cannot assign a device for operation 'MobilenetV2/input': Operation was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device.
09-05 14:07:58.741 16317 16441 E AndroidRuntime: 	 [[Node: MobilenetV2/input = Identity[T=DT_FLOAT, _device="/device:GPU:0"](fifo_queue_Dequeue)]]
           

這是因為我們訓練的時候用的是GPU,儲存的pb中指定了用GPU進行load,而手機中隻有CPU,是以所發生錯誤。

改動方法是訓練的時候添加一個flag --clone_on_cpu=True,就可以将我們的meta儲存device指定CPU。

代碼解讀

android代碼中關于分類網絡的主要是兩個檔案,一個是ClassifierActivity.java,另一個是TensorFlowImageClassifier.java

1.ClassifierActivity.java

這個檔案主要負責camera的preview,将preview中的圖檔傳遞給TensorFlowImageClassifier進行分類網絡的預測,最後顯示預測結果。

# 建立classifier執行個體
classifier =
    TensorFlowImageClassifier.create(
        getAssets(),
        MODEL_FILE,
        LABEL_FILE,
        INPUT_SIZE,
        IMAGE_MEAN,
        IMAGE_STD,
        INPUT_NAME,
        OUTPUT_NAME);

# 調用recognizeImage進行圖像識别
final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);

# 顯示預測結果
resultsView.setResults(results);
           

2.TensorFlowImageClassifier.java

主要負責調用TensorFlowInferenceInterface類的接口進行預測。

# 執行個體化TensorFlowInferenceInterface,同時會将model載入
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

# 傳入input的image資料
inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);

# 進行計算
inferenceInterface.run(outputNames, logStats);

#取出計算結果
inferenceInterface.fetch(outputName, outputs);
           

以上就是全部内容,如果在操作過程中遇到了任何問題可以給我留言,謝謝閱讀。