天天看點

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

點選即可參與機器學習PAI-DSW動手實驗室 點選可下載下傳完整電子書《阿裡雲機器學習PAI-DSW入門指南》

雖然已經 9102 年了MNIST手寫資料集也早已經被各路神仙玩出了各種

花樣

,比如其中比較秀的有用MINST訓練手寫日語字型的。但是目前還是很少有整體的将訓練完之後的結果部署為一個可使用的服務的。大多數還是停留在最終Print出一個Accuracy。

這一次,借助阿裡雲的PAI-DSW來快速建構訓練一個手寫模型并且部署出一個生産可用級别的服務的教程讓大家可以在其他的産品中調用這個服務作出更加有意思的項目。

這篇文章裡我們先講講如何建構訓練并導出這個手寫字型識别的模型。整個教程的代碼基于Snapchat的ML大佬 Aymeric Damien 的

Tensorflow 入門教程系列

第一步: 下載下傳代碼

首先我們可以把代碼Clone到本地或者直接Clone到DSW的執行個體。如何Clone到DSW執行個體的方法可以參考我的

這篇文章

。Clone完代碼之後我們還需要準備訓練所需要的資料集這邊可以直接從

Yann Lecun的網站

下載下傳。我這邊然後我們可先運作一遍看一下效果。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

我們可以看到代碼Clone下來之後直接運作就已經幫我們訓練出了model并且給出了現在這個Model的精度。在500個batch之後準确率達到了95%以上而且基于GPU的DSW執行個體訓練這500個Batch隻需要十幾秒的時間。

第二步: 修改部分代碼使得可以自動導出SavedModel

這一步就是比較重要的地方了我們第一個需要關注的就是目前的這個Model裡面的Input和Output.

Input還比較清楚我們直接找所有placeholder就可以了。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

Output這一塊就比較複雜了,在目前的model裡我們可以看到output并不是直接定義的Y而是softmax之後的prediction

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

找到了這些之後就比較簡單了。首先我們建立一個 Saver , 它可以幫助我們儲存所有的tf變量以便之後導出模型使用

# 'Saver' op to save and restore all the variables
saver = tf.train.Saver()           

然後我們在模型訓練的session結束的時候導出模型就好了。我們可以通過以下這段代碼來導出我們訓練好的模型。

import datetime
# 聲明導出模型路徑這邊加入了時間作為路徑名 這樣每次訓練的時候就可以儲存多個版本的模型了
export_path = "./model-" + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

# 儲存訓練的日志檔案友善如果出問題了我們可以用 tensorboard 來可視化神經網絡排查問題
tf.summary.FileWriter('./graph-' + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S') , sess.graph)

# 建構我們的Builder
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

# 聲明各種輸入這裡有一個X和一個keep_prob作為輸入然後
tensor_info_x = tf.saved_model.utils.build_tensor_info(X)
tensor_info_keep_prob = tf.saved_model.utils.build_tensor_info(keep_prob)
tensor_info_y = tf.saved_model.utils.build_tensor_info(prediction)

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        # 聲明輸入
        inputs={
            'images': tensor_info_x,
            'keep_prob' : tensor_info_keep_prob
        },
        # 聲明輸出
        outputs={
            'scores': tensor_info_y
        },
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
    )
)



legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'predict_images':
            prediction_signature,
    },
    legacy_init_op=legacy_init_op)
# 儲存模型
builder.save()           

我們可以把這段代碼插在這裡這樣訓練完成的時候就會自動導出了。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

導出之後應該會有如下的檔案結構我們也可以在左邊的檔案管理器中檢視。

./model-2019-05-20_13:50:26
├── saved_model.pb
└── variables
    ├── variables.data-00000-of-00001
    └── variables.index

1 directory, 3 files           

第三步: 部署我們的模型

終于到了可以部署的階段了。但是在部署之前先别那麼着急建議用

tensorboard

把訓練日志下載下傳到本地之後看一下。

這一步除了可以可視化的解釋我們的模型之外還可以幫助我們理清我們的模型的輸入和輸出分别是什麼。

這邊我先在有日志檔案的路徑打開一個tensorboard 通過這個指令

tensorboard --logdir ./           

然後我們在遊覽器裡輸入預設位址 localhost:6006 就可以看到了。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

從這個圖裡也可以看到我們的這個Model裡有2個輸入源分别叫做images和keep_prob。并且點選它們之後我們還能看到對應的資料格式應該是什麼樣的。不過沒有辦法使用 Tensorboard 的同學也不用擔心因為

EAS

這個産品也為我們提供了構造請求的方式。這一次部署我們先使用WEB界面來部署我們的服務這一步也可以通過

EASCMD

來實作之後我會再寫一篇如何用好EASCMD的文章。

我們可以把模型檔案下載下傳完之後用zip打包然後到PAI産品的控制台點選EAS-模型線上服務。

ZIP打包可以用這個指令如果你是Unix的使用者的話

zip -r model.zip path/to/model_files           

進入EAS之後我們點選模型部署上傳

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

然後繼續配置我們的processor這一次因為我們是用tensorflow訓練的是以選擇Tensorflow

然後資源選擇CPU有需要的同學可以考慮GPU然後上傳我們的模型檔案。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

點選下一步我們選建立服務然後給我們的服務起個名字,并且配置資源數量。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

然後最後确認一下就可以點選部署了。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

第四步: 調試我們的模型

回到EAS的控制台我們可以看到我們的服務正在被建構中。等到狀态顯示Running的時候我們就可以開始調試了。

我們可以先點選線上調試。

四步訓練出自己的CNN手寫識别模型 | 《阿裡雲機器學習PAI-DSW入門指南》

會讓我們跳轉到一個Debug 接口的頁面。什麼都不需要填直接點選送出我們就可以看到服務的資料格式了。

然後我們用一段python2的代碼來調試這個剛剛部署完的服務。python3的SDK暫時還在研發中。注意要把下面的

app_key, app_secret, url 換成我們剛剛部署好的内容。點選模型名字就可以看見了。

其中測試圖檔的資料大家可以在

下載下傳到。

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import json

from urlparse import urlparse
from com.aliyun.api.gateway.sdk import client
from com.aliyun.api.gateway.sdk.http import request
from com.aliyun.api.gateway.sdk.common import constant
from pai_tf_predict_proto import tf_predict_pb2

import cv2
import numpy as np

with open('9.jpg', 'rb') as infile:
    buf = infile.read()
    # 使用numpy将位元組流轉換成array
    x = np.fromstring(buf, dtype='uint8')
    # 将讀取到的array進行圖檔解碼獲得28 × 28的矩陣
    img = cv2.imdecode(x, cv2.IMREAD_UNCHANGED)
    # 由于預測服務API需要長度為784的一維向量将矩陣reshape成784
    img = np.reshape(img, 784)

def predict(url, app_key, app_secret, request_data):
    cli = client.DefaultClient(app_key=app_key, app_secret=app_secret)
    body = request_data
    url_ele = urlparse(url)
    host = 'http://' + url_ele.hostname
    path = url_ele.path
    req_post = request.Request(host=host, protocol=constant.HTTP, url=path, method="POST", time_out=6000)
    req_post.set_body(body)
    req_post.set_content_type(constant.CONTENT_TYPE_STREAM)
    stat,header, content = cli.execute(req_post)
    return stat, dict(header) if header is not None else {}, content


def demo():
    # 輸入模型資訊,點選模型名字就可以擷取到了
    app_key = 'YOUR_APP_KEY'
    app_secret = 'YOUR_APP_SECRET'
    url = 'YOUR_APP_URL'

    # 構造服務
    request = tf_predict_pb2.PredictRequest()
    request.signature_name = 'predict_images'
    request.inputs['images'].dtype = tf_predict_pb2.DT_FLOAT  # images 參數類型
    request.inputs['images'].array_shape.dim.extend([1, 784])  # images參數的形狀
    request.inputs['images'].float_val.extend(img)  # 資料

    request.inputs['keep_prob'].dtype = tf_predict_pb2.DT_FLOAT  # keep_prob 參數的類型
    request.inputs['keep_prob'].float_val.extend([0.75])  # 預設填寫一個

    # å°†pbåºåˆ—化æˆstringè¿›è¡Œä¼ è¾“
    request_data = request.SerializeToString()

    stat, header, content = predict(url, app_key, app_secret, request_data)
    if stat != 200:
        print 'Http status code: ', stat
        print 'Error msg in header: ', header['x-ca-error-message'] if 'x-ca-error-message' in header else ''
        print 'Error msg in body: ', content
    else:
        response = tf_predict_pb2.PredictResponse()
        response.ParseFromString(content)
        print(response)


if __name__ == '__main__':
    demo()           

運作這個python代碼然後我們會得到

outputs {
  key: "scores"
  value {
    dtype: DT_FLOAT
    array_shape {
      dim: 1
      dim: 10
    }
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 0.0
    float_val: 1.0
  }
}           

我們可以看到從0開始數的最後一個也就是第9個的結果是1 其他都是0 說明我們的結果是9和我們輸入的一樣。這樣我們就簡單輕松的建構了一個線上服務能夠将使用者的圖檔中手寫數字識别出來。配合其他Web架構或者更多的東西我們就可以作出更好玩的玩具啦。