天天看點

我不會用 Triton 系列:如何實作一個 backend

如何實作一個 backend

這篇文章主要講如何實作一個 Triton Backend,以 Pytorch Backend 為例子。

Backend API

我們需要實作兩個類來存儲狀态以及七個 Backend API。

ModelState

ModelInstanceState

TRITONBACKEND_Initialize

TRITONBACKEND_Finalize

TRITONBACKEND_ModelInitialize

TRITONBACKEND_ModelFinalize

TRITONBACKEND_ModelInstanceInitialize

TRITONBACKEND_ModelInstanceFinalize

TRITONBACKEND_ModelInstanceExecute

ModelState 和 ModelInstanceState 這兩個類可以綁定到 Triton 提供的指針上,你可以在 ModelState 和 ModelInstanceState 裡面存儲任何你想要的狀态,然後将它綁定到 Triton 提供的指針上。這兩個類并非必須的,它的作用相當于存儲。在閱讀了 Pytorch Backend 之後,會發現如果要寫新的 Backend,七個 Backend API 并不需要做任何更改,隻需要修改 ModelState 和 ModelInstanceState 即可。這兩個類裡面隻需要做幾個事情:模型配置檔案檢驗、處理請求。

簡單概括就是:

動态連結庫加載的時候,執行 TRITONBACKEND_Initialize

當屬于一個 Backend 的模型都被删除了,且 Triton 開啟了熱更新,它會解除安裝動态連結庫,執行 TRITONBACKEND_Finalize

其他的方法看名字就好了,不然就說了很多廢話。比如 “在模型初始化的時候,調用模型初始化”

一個 Backend 對應多個 Model,Backend 隻調用一次,Model 調用次數和倉庫中模型數量一樣多

一個 Model 對應多個 ModelInstance,根據模型的配置檔案,調用 “模型執行個體” 的初始化方法。

Pytorch Backend 例子

位址:https://github.com/triton-inference-server/pytorch_backend/blob/main/src/libtorch.cc

我們以 Pytorch Backend 為學習例子,看看應該如何實作。

一個 ModelState 和一個 TRITONBACKEND_Model 相關聯,這個類主要提供一些模型配置檢查、參數校驗、模型執行個體共用的屬性和方法。比如,加載模型的方法是所有模型執行個體初始化的時候需要的。

一個 ModelInstanceState 和一個 TRITONBACKEND_ModelInstance 相關聯,多個 ModelInstanceState 共享一個 ModelState。這個類主要提供一些處理請求、前向傳播執行的方法。

令人頗感困惑的是,Pytorch 将模型輸入輸出配置的檢查放到了這個函數裡面,而 Tensorflow 的 backend 實作中,将模型輸入輸出的檢查放到 ModelState。從抽象的分層來看,我認為模型配置的檢查應該放到執行個體化之前,這樣就可以避免每次初始化 “模型執行個體” 的時候都檢查一次。Pytorch 這麼做的原因是,設計了一個 ModelInstanceState 相關的内部狀态 <code>input_index_map_</code>,這個狀态的初始化依賴于模型的配置。

Pytorch Backend 裡面沒有什麼需要特别處理的東西,就是模闆代碼就好了,列印 backend 名字和版本之類的。

沒有提供實作。解除安裝動态連結庫,直接移除就好了,沒有需要清理的東西。

調用 Create 方法建立一個 ModelState,使用 TRITONBACKEND_ModelSetState 将 ModelState 綁定到傳進來的 TRITONBACKEND_Model 上面。

前面綁定的是一個指針,是以要在這裡删除指針。

“模型執行個體” 初始化和 TRITONBACKEND_ModelInitialize 的邏輯基本一緻。不過需要使用多幾個 API,這個方法傳進來隻有模型執行個體,我們可以從執行個體拿到綁定的 Model,再從 Model 拿出 ModelState,然後調用 “模型執行個體” 的 Create 方法進行初始化,最後同樣調用 API 綁定到 ModelInstance。

這個 API 的輸入是 “模型執行個體” 和 “請求”,這裡從 “模型執行個體” 中取出 ModelInstanceState,然後調用處理請求的方法即可。

實作細節

在 Pytorch 的實作中,将模型配置檔案的檢驗放到了 “模型執行個體” 初始化的時候,因為它設計了一些 “模型執行個體” 相關的狀态,并且需要使用到模型配置檔案。于是它一邊進行模型配置檔案的檢驗,一邊初始化 “模型執行個體” 相關的狀态。

在 OneFlow 的實作中,計劃将模型配置檔案的校驗放到模型初始化 <code>TRITONBACKEND_ModelInitialize</code> 裡面,而不是 “模型執行個體” 初始化的時候。不過,這取決于後面的 OneFlow C++ API 是如何實作的。

那麼 Pytorch Backend 是如何實作的呢?

整個檢驗過程是:進行模型配置檔案的檢驗,之後設定 “模型執行個體” 相關的狀态。一邊分析輸入輸出的名字是否符合規則,一邊将輸入輸出的名字映射到 id。這麼設計的 主要原因 是 forward 接口需要使用者按照一定順序将 tensor 輸入,并沒有提供一個 map 結構的輸入。寫成這樣的 次要原因 是提高效率,一次 parse 就好。

關鍵資料結構:

關鍵函數調用:

如果 Triton 開啟了 GPU,那麼需要做一些特别的處理。比如在同步 CudaStream。

TRITONBACKEND_ModelInstanceExecute 拿到的是一個二維指針,即一個請求的數組,這一批請求需要一起做處理。在進行前向傳播之前,我們需要将輸出收集起來。Triton 提供了一些工具類幫助我們去做收集,估計下面這個收集的方法是一個異步的方法,這樣可以提高性能,不過需要我們顯式使用同步操作。下面的 <code>input_buffer</code> 是一個指針,可能指向 Host,也可能指向 Device,不管這個指針指向哪裡,後面使用 libtorch 的方法,建立一個 tensor,這樣我們就擷取了輸入了。

需要注意的是:配置設定記憶體需要調用 Triton 的方法,然後用 torch 建立 tensor。不管 Tensor 所屬的記憶體是 CPU 還是 GPU 的,都是由 Triton 來管理。

關鍵資料結構

關鍵 API 調用

在前處傳播完了之後,就可以擷取到輸出的 Tensor 了,我們隻需要從 Tensor 中取出資料指針就可以了,然後調用 Triton 提供的工具,幫我們将資料拷貝到指定的 repsonse 上面。

時間戳宏

為了友善擷取時間戳,Triton 提供了一個宏函數,友善調用。

需要哪些時間戳

統計資料

總結

簡單梳理了一下,其實就幾個事情:

模型配置檔案檢測,驗證輸入輸出的寫法是否正确

請求和響應,從請求中取出輸入,将輸出寫到響應

記憶體管理的方法,Triton 管理記憶體,Pytorch 則接收或輸出一個指針

統計資料,擷取幾個時間戳,調用 Triton API 來設定

深度學習架構在上面的過程中,隻負責了一小部分,從 Triton 拿到指針,變成一個架構可以處理的 Tensor,然後進行推理,擷取輸出,最後将輸出變成一個指針,傳回給 Triton。于是,Triton 拿到指針之後,寫到 response 裡面。