天天看點

pytorch轉onnx驗證_ONNX轉PyTorch腳本驗證與更新

之前用torchvision裡面的官方模型做了腳本的測試,今天手賤,沒忍住,用onnx官方的models做了驗證,于是帶來了新的血淚。

onnx/models​github.com

pytorch轉onnx驗證_ONNX轉PyTorch腳本驗證與更新

gist更新了腳本,增加了驗證代碼的檔案onnx2pytorch_validate.py:

https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8​gist.github.com

腳本用血淚更新後,基本起碼可以支援以下onnx/models中的幾個模型了:

  • googlenet
  • resnet18v2
  • resnet34v2
  • squeezenet1.1
  • mobilenetv2-1.0
  • alex_net
  • densenet121
  • vgg16
重模組化型後誤差較小的特征:

滿足下面兩點的模型誤差會顯著

減少

  1. 現代一點的架構:基本就是加residual
  2. 去掉Softmax的導出(可能是Softmax的各自實作略有不同?)
  3. 如果直接導出為raw_data,也就是是二進制bytes格式,可能精度丢失的會更少

當然,

還有一個更大的可能,我某個轉換代碼有問題

誤差測試結果

誤差測試方法:

使用MXNet導入onnx VS 腳本導入模型,輸出以下内容:

print           
運作結果:

desenet121的結果沒有flatten,太惡心了,不要怪我,轉出來的模型就這樣。

================================================================================
googlenet.onnx :
Input Blob Names: ['data_0']
<Symbol softmax0>
0.038564444 0.00010830386 1.1318343e-07
[0.00047492 0.00078218 0.00057791 0.00075334 0.00317502]
[0.0004738  0.00075904 0.0004524  0.00064938 0.0024273 ]
================================================================================
resnet18v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add1>
4.053116e-06 8.2214177e-07 0.0
[-2.7655013   0.78636014  0.33238387  0.44985968  0.8458633 ]
[-2.7655027   0.7863597   0.33238354  0.44986013  0.8458635 ]
================================================================================
resnet34v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add2>
4.2915344e-06 1.0785647e-06 0.0
[-2.2723439   0.04653598  0.30353007  0.2595243   0.41742972]
[-2.2723436   0.04653586  0.30352995  0.2595226   0.41742748]
================================================================================
squeezenet1.1.onnx :
Input Blob Names: ['data']
<Symbol squeezenet0_flatten0_reshape0>
5.2452087e-06 7.539566e-07 0.0
[0.5420265 4.364538  4.056094  4.889414  4.368317 ]
[0.5420265 4.364538  4.056096  4.889414  4.368317 ]
================================================================================
mobilenetv2-1.0.onnx :
Input Blob Names: ['data']
<Symbol mobilenetv20_output_flatten0_reshape0>
7.390976e-06 1.855201e-06 0.0
[-3.0427516   1.1981574   0.29107922 -0.53182554  1.0944995 ]
[-3.04275     1.198158    0.29107827 -0.5318265   1.0945004 ]
================================================================================
alex_net.onnx :
Input Blob Names: ['data_0']
<Symbol softmax1>
0.00015634485 9.566847e-06 7.887138e-09
[0.00150446 0.00085357 0.00193772 0.00269155 0.00280163]
[0.0015188  0.0008455  0.00191552 0.00266171 0.0027587 ]
================================================================================
densenet121.onnx :
Input Blob Names: ['data_0']
<Symbol convolution318>
5.00679e-06 8.735047e-07 0.0
[[[-1.742769  ]]
 [[-0.05456648]]
 [[ 1.0543407 ]]
 [[ 0.522862  ]]
 [[ 0.82946616]]]
[[[-1.7427675 ]]
 [[-0.05456671]]
 [[ 1.054343  ]]
 [[ 0.5228648 ]]
 [[ 0.82946706]]]
================================================================================
vgg16.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add129>
1.1920929e-06 2.18953e-07 0.0
[-1.8850654   1.1514534  -0.01803813  0.7195613  -0.06146571]
[-1.8850652   1.1514536  -0.01803814  0.7195611  -0.06146575]
           

之前的記錄連結:

Captain Jack:一個ONNX模型轉Pytorch的玩具腳本​zhuanlan.zhihu.com

pytorch轉onnx驗證_ONNX轉PyTorch腳本驗證與更新