之前用torchvision裡面的官方模型做了腳本的測試,今天手賤,沒忍住,用onnx官方的models做了驗證,于是帶來了新的血淚。
onnx/modelsgithub.com

gist更新了腳本,增加了驗證代碼的檔案onnx2pytorch_validate.py:
https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8gist.github.com
腳本用血淚更新後,基本起碼可以支援以下onnx/models中的幾個模型了:
- googlenet
- resnet18v2
- resnet34v2
- squeezenet1.1
- mobilenetv2-1.0
- alex_net
- densenet121
- vgg16
滿足下面兩點的模型誤差會顯著
減少:
- 現代一點的架構:基本就是加residual
- 去掉Softmax的導出(可能是Softmax的各自實作略有不同?)
- 如果直接導出為raw_data,也就是是二進制bytes格式,可能精度丢失的會更少
當然,
還有一個更大的可能,我某個轉換代碼有問題。
誤差測試結果
誤差測試方法:使用MXNet導入onnx VS 腳本導入模型,輸出以下内容:
print
運作結果: desenet121的結果沒有flatten,太惡心了,不要怪我,轉出來的模型就這樣。
================================================================================
googlenet.onnx :
Input Blob Names: ['data_0']
<Symbol softmax0>
0.038564444 0.00010830386 1.1318343e-07
[0.00047492 0.00078218 0.00057791 0.00075334 0.00317502]
[0.0004738 0.00075904 0.0004524 0.00064938 0.0024273 ]
================================================================================
resnet18v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add1>
4.053116e-06 8.2214177e-07 0.0
[-2.7655013 0.78636014 0.33238387 0.44985968 0.8458633 ]
[-2.7655027 0.7863597 0.33238354 0.44986013 0.8458635 ]
================================================================================
resnet34v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add2>
4.2915344e-06 1.0785647e-06 0.0
[-2.2723439 0.04653598 0.30353007 0.2595243 0.41742972]
[-2.2723436 0.04653586 0.30352995 0.2595226 0.41742748]
================================================================================
squeezenet1.1.onnx :
Input Blob Names: ['data']
<Symbol squeezenet0_flatten0_reshape0>
5.2452087e-06 7.539566e-07 0.0
[0.5420265 4.364538 4.056094 4.889414 4.368317 ]
[0.5420265 4.364538 4.056096 4.889414 4.368317 ]
================================================================================
mobilenetv2-1.0.onnx :
Input Blob Names: ['data']
<Symbol mobilenetv20_output_flatten0_reshape0>
7.390976e-06 1.855201e-06 0.0
[-3.0427516 1.1981574 0.29107922 -0.53182554 1.0944995 ]
[-3.04275 1.198158 0.29107827 -0.5318265 1.0945004 ]
================================================================================
alex_net.onnx :
Input Blob Names: ['data_0']
<Symbol softmax1>
0.00015634485 9.566847e-06 7.887138e-09
[0.00150446 0.00085357 0.00193772 0.00269155 0.00280163]
[0.0015188 0.0008455 0.00191552 0.00266171 0.0027587 ]
================================================================================
densenet121.onnx :
Input Blob Names: ['data_0']
<Symbol convolution318>
5.00679e-06 8.735047e-07 0.0
[[[-1.742769 ]]
[[-0.05456648]]
[[ 1.0543407 ]]
[[ 0.522862 ]]
[[ 0.82946616]]]
[[[-1.7427675 ]]
[[-0.05456671]]
[[ 1.054343 ]]
[[ 0.5228648 ]]
[[ 0.82946706]]]
================================================================================
vgg16.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add129>
1.1920929e-06 2.18953e-07 0.0
[-1.8850654 1.1514534 -0.01803813 0.7195613 -0.06146571]
[-1.8850652 1.1514536 -0.01803814 0.7195611 -0.06146575]
之前的記錄連結:
Captain Jack:一個ONNX模型轉Pytorch的玩具腳本zhuanlan.zhihu.com