天天看點

Pytorch之TorchScript和trace儲存網路的差別

TorchScript和trace的差別
model = MyModule()
script = torch.jit.trace(model, inputs) # 适合固定形式的控制流(也就是沒有if else)
script = torch.jit.script(model) #适合動态控制流(有if else)
           

可以觀察兩者接收的參數,會發現兩者都需要傳入model,但是trace還需要喂入示例資料,這樣才能觀察到整個網路,而後者是不需要的。

  • torch.jit.trace() 适合固定形式的控制流(也就是沒有if else)
  • torch.jit.script() 适合動态控制流(有if else)

接着把兩者列印出來:

print(script.code)
           

你會發現居然還原了代碼,這樣做的好處是友善于生産部署,得到script後我們就可以序列化儲存為檔案:

model.save("final_model.pt")
           

然後就可以脫離Python,在C++上加載運作了。

繼續閱讀