天天看點

使用NCNN在移動端部署深度學習模型

一、整體流程概覽

  1. 訓練模型,使用各種你熟悉的架構我用的是pytorch
  2. 将*.pth轉換成onnx, 優化onnx模型
  3. 使用轉換工具轉換成可供ncnn使用的模型
  4. 編譯ncnn架構,并編寫c++代碼調用上一步轉換的模型,得到模型的輸出結果,封裝成可供調用的類
  5. 使用JNIC++調用上一步C++封裝的類,提供出接口
  6. 在安卓端編寫java代碼再次封裝一次,供應用層調用

二、将*.pth轉換成onnx

使用pytorch自帶的torch.onnx即可,需要1.1版本以上,這裡有一點需要注意,torch的API有些是onnx不支援的,如果轉換的時候報錯就把模型裡的函數改成onnx支援的吧,有些文章裡說這裡可以設定opset_version=12來解決,但是這樣的話在後面轉換到ncnn或者mnn的時候造成轉換失敗,應該是ncnn還沒支援到更高版本的onnx的原因。在最後輸出之前有個torch.randn()函數,這裡的參數格式是[b,c, w,h]這裡也不是随便寫的,b固定是1了,你模型的輸入通道是多少就寫多少,後面的就是模型的輸入,這裡一旦固定了,後面在第5步的時候c++裡的輸入也就固定了

convet2onnx.py

# -*- coding:utf-8 -*-
# name: convert_onnx
# author: bqh
# datetime:2020/6/17 10:31
# =========================
import torch
def load_model(model, pretrained_path):
    print('Loading pretrained model from {}'.format(pretrained_path))
    pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)    
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
    
output_onnx = '../weights/output.onnx'
raw_weights = '../weights/model.pth'

# load weight
net = you_net()
net = load_model(net, raw_weights)
net.eval()
print('Finished loading model!')
device = torch.device("cuda")
net = net.to(device)

input_names = ["input0"]
output_names = ["output0"]
inputs = torch.randn(1, 3, 300, 300).to(device)
torch_out = torch.onnx._export(net, inputs, output_onnx, export_params=True, verbose=False,keep_initializers_as_inputs=True, input_names=input_names, output_names=output_names)
           

安裝onnx簡化工具

pip3 install onnx-simplifier onnxruntime
           

簡化onnx模型

這一步一定要做,否則後面轉onnx的時候會報錯

python3 -m onnxsim model.onnx model_sim.onnx
           

三、編譯NCNN架構

主要參考ncnn官網的教程即可,windows下編譯同上一篇的MNN的編譯都差不多,隻有一點需要說明,官網的教程上有vulkan-sdk的安裝然後打開-DNCNN_VULKAN=ON編譯選項。我一切照做後編譯出來的ncnn.lib在運作ncnn::Extractor ex = Net->create_extractor();這個函數後的所有操作之後,傳回的時候就報堆棧溢出錯誤,包括加載官網給出的例子全部報錯;後來不cmake的時候這個編譯選項不打開編譯出來的ncnn.lib就一切正常了。可能是自己的問題,也沒去深究。反正能用就OK了。我把編譯出來的ncnn.lib ncnn.a和linux下的onnx2ncnn工具都放在了我的網盤裡,不想被編譯折磨的就直接去下吧。如果編譯遇見問題,也可以給我留言,哈哈~

說明:ncnnd.lib是windows下的debug版本,ncnn.lib是release版本,libncnn.a是linux下的庫檔案,onnx2ncnn是linux下的轉換工具。

下載下傳位址:NCNN 提取碼:6cuc

四、C++調用和封裝

說明

對于vs中lib庫和include目錄的配置就不贅述了,有不懂的之前的文章有提過,假定工程已經配置完成。大體的調用過程NCNN和MNN都差不多,先加載模型建立一個指向模型的指針,然後建立session、建立用于處理輸入的tensor,将input_tensor送入session,運作session,最後得到網絡的輸出。如果對C++比較熟悉的話,看着官網的教程比葫蘆畫瓢即可,隻有一個地方需要說明就是對輸出的獲得。先看下我的代碼和官網的代碼再說為什麼

我的輸出:

// run net and get output
	ncnn::Mat out, out1;
	ret = ex.extract("output0", out);
	ex.extract("376", out1);
           

官網的例子輸出:

ncnn::Mat out;
    ex.extract("detection_out", out);
           

辣麼問題來了,我的"output0"和"376"、官網的“detection_out”都哪裡來的?有兩個地方可以得到,最簡單的方法,使用MNN架構下的轉換工具,在轉換完成的時候會給出模型的輸入和輸出名稱,直接拷貝即可

>MNNConvert.exe -f ONNX --modelFile model.onnx --MNNModel slime.mnn  --bizCode biz

MNNConverter Version: 0.2.1.5git - MNN @ 2018


Start to Convert Other Model Format To MNN Model...
[17:49:58] :29: ONNX Model ir version: 6
Start to Optimize the MNN Net...
[17:49:58] :20: Inputs: input0
[17:49:58] :37: Outputs: output0, Type = Concat
[17:49:58] :37: Outputs: 376, Type = Softmax
Converted Done!
           

如果沒有MNN的轉換工具,在後面加載模型後單步跟一下,在Net = new ncnn::Net()變量中有個blob變量,在記憶體中檢視一下,裡面存的有模型的各個層的名稱。代碼中的img_w,img_h就是在第二步轉換的時候你指定的w,h。這裡隻寫了核心調用函數,具體使用時還請自行添加一些輔助函數!

C++代碼

detection.h

#pragma once

#include <opencv2/opencv.hpp>
#include <string>
#include <stack>
#include "net.h"
#include <stdio.h>
#include <algorithm>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/opencv.hpp>
#include <fstream>
#include "omp.h"


struct bbox {
	float x1;
	float y1;
	float x2;
	float y2;
	float s;
};

struct box {
	float cx;
	float cy;
	float sx;
	float sy;
};

struct ObjectInfo {
	float x1; //bbox的left
	float y1; //bbox的top
	float x2; //bbox的right
	float y2; //bbox的bottom
	float prob; //置信度
};


class ObjectDetection
{
private:
	float _nms = 0.4;
	float _threshold = 0.6;
	const float mean_vals[3] = { 104.f, 117.f, 123.f };
	const float norm_vals[3] = { 1.0 / 104.0, 1.0 / 117.0, 1.0 / 123.0 };
	cv::Mat img;
	ncnn::Net *Net;
	int img_w = 300;
	int img_h = 300;
	int numThread;
	int detect_count = 0;
	static inline bool cmp(bbox a, bbox b);
public:
	ObjectDetection(std::string modelFolder, int num_thread);
	~ObjectDetection();
	int Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj);
};
           

detection.cpp

#include "Detection.h"
#include <cmath>

ObjectDetection::ObjectDetection(std::string modelFolder, int num_thread)
{
	Net = new ncnn::Net();
	std::string model_param = modelFolder + "Detect.param";
	std::string model_bin = modelFolder + "Detect.bin";
	int ret = Net->load_param(model_param.c_str());
	ret = Net->load_model(model_bin.c_str());	
	numThread = num_thread;
}

ObjectDetection::~ObjectDetection()
{
	if (Net != nullptr)
	{
		delete Net;
		Net = nullptr;
	}
}

int ObjectDetection::Detect(unsigned char *inputImage, int inputw, int inputh, std::vector<ObjectInfo > &obj)
{
	int ret = -1;	
	ncnn::Mat in = ncnn::Mat::from_pixels_resize(inputImage, ncnn::Mat::PIXEL_BGR, inputw, inputh, img_w, img_h);
	in.substract_mean_normalize(mean_vals, norm_vals);	
	ncnn::Extractor ex = Net->create_extractor();		
	ex.set_light_mode(true);
	ret = ex.input("input0", in);	
	// run net and get output
	ncnn::Mat out, out1;
	// bbox的輸出
	ret = ex.extract("output0", out);
	ex.extract("376", out1);

	// get result	
	for (int i = 0; i < out.h; ++i)
	{
		// 得到網絡的具體輸出
		const float *boxes = out.row(i);
		const float *scores = out1.row(i);
		// 執行你自己的操作
	}
	std::sort(total_box.begin(), total_box.end(), cmp);
	NMS(total_box, _nms);
	return 0;
}
           

五、 編寫JNI C++

在Android Studio中配置NDK,具體配置網上有很多教程我就不啰嗦了,假定android strdio的jni c++環境已經配置完成。源碼中的函數名的格式是jni c++要求的,必須這種格式,根據實際情況修改,函數名中的"com_example_demokit_Detection"對應到java的應用中就是"com.example.demokit.Detection"這樣就很好了解了。

native-lib.cpp

#include <jni.h>
#include <string>

#include "Detection.h"
#include <android/log.h>

extern "C" JNIEXPORT jlong JNICALL
Java_com_example_demokit_Detection_Create(JNIEnv *env, jobject instance, jstring path) {
    char* _path;
    _path = (char*)env->GetStringUTFChars(path,0);
    Detection *phandle = new Detection(_path, 2);
    return (jlong)phandle;
}

extern "C" JNIEXPORT jintArray JNICALL
Java_com_example_demokit_Detection_Detect(JNIEnv *env, jobject instance, jlong handle, jint campos, jint w, jint h, jbyteArray data_) {
    Detection *gp = NULL;
    if (handle)
        gp = (Detection *)handle;
    else
        return nullptr;
    jbyte *data = env->GetByteArrayElements(data_, NULL);

    std::vector<ObjectInfo> objects;
    gp->Detect((unsigned char*)data, w, h, objects);
    env->ReleaseByteArrayElements(data_, data, 0);
    jintArray jarr = env->NewIntArray(objects.size()*15+1);
    jint *arr = env->GetIntArrayElements(jarr, NULL);
    arr[0] = objects.size();
    for (int i = 0; i < objects.size(); i++)
    {
        arr[i*5 + 1] = objects[i].x1;
        arr[i*5 + 2] = objects[i].y1;
        arr[i*5 + 3] = objects[i].x2;
        arr[i*5 + 4] = objects[i].y2;
        arr[i*5 + 5] = objects[i].prob;
    }
    env->ReleaseIntArrayElements(jarr, arr, 0);
    return jarr;
}
           

六、java調用

package com.example.demokit;

public class Detection {
    static {
        System.loadLibrary("native-lib");
    }
    private long handle;
    public Detection(String path){
        handle = Create(path);
    }
    public int[] Detect(int w, int h, byte[] data){
            return  Detect(handle, w, h, data);
    }
    private native long Create(String path);
    private native int[] Detect(long handle, int w, int h, byte[] data);
}
           

七、 應用層使用

在應用層就可以直接調用上面的java類啦,搞定~

package com.example.demokit;
import android.graphics.Point;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DetectTool {
    private Detection mDetection;
    private static final int DATA_LENGTH = 5; // 矩形框坐标2個,每個具有x,y兩個值;置信度1個;

    public DetectTool(String dect_model_dir){
        /**
         * @dect_model_dir: 檢測模型所在的目錄路徑
         */
        mDetection = new Detection(dect_model_dir);
    }

    private ObjectInfo ArrayAnalysis(int[] src_array){
        /**
         * 對輸入的數組進行解析,傳回ObjectInfo對象
         * @src_array: 具有DATA_LENGTH所示結構的數組
         */
        ObjectInfo obj_info = new ObjectInfo();

        Point[] pointFaceBox = new Point[2];
        // face_bbox 坐标
        for(int i = 0; i < 2; i++) {
            Point point = new Point();
            point.x = src_array[2*i];
            point.y = src_array[2*i+1];
            pointFaceBox[i] = point;
        }
        // 置信度
        obj_info.setProb(src_array[4]);
        return obj_info ;
    }

    public List<ObjectInfo> GetObjectInfo(int width, int height, byte[] data){
        /**
         * @width:圖檔寬度
         * @height:圖檔高度
         * @data:圖檔的位元組流
         */
        int[] obj= mDetection.Detect(width, height, data);
        List<ObjectInfo> obj_list = new ArrayList<>();
        int obj_count = obj[0];
        for(int i = 0; i < obj_count ; i++){
            int[] obj_array = Arrays.copyOfRange(obj, i*DATA_LENGTH + 1, (i + 1) * DATA_LENGTH+1);
            ObjectInfo obj_info = this.ArrayAnalysis(obj_array);
            obj_list.add(obj_info);
        }
        return obj_list;
    }
}
           

繼續閱讀