天天看點

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

上文我們看到了AutogradMetadata,DistAutogradContainer 和 DistAutogradContext 等一系列基礎類。我們知道了分布式autograd如何基于RPC進行傳遞,如何在節點之間互動,節點如何區分維護這些Session。本文繼續分析,主要目的是看看反向傳播如何切入到引擎之中。

目錄

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

0x00 摘要

0x01 前文回憶

0x02 計算圖

2.1 普通示例

2.2 分布式示例

2.3 分布式注釋版

0x03 反向傳播

3.1 發起反向傳播

3.1.1 外部主動發起

3.1.1.1 示例

3.1.1.2 C++世界

3.1.2 内部隐式發起

3.1.2.1 BACKWARD_AUTOGRAD_REQ

3.1.2.2 PropagateGradientsReq

3.2 接受反向傳播

3.2.1 接受消息

3.2.2 處理消息

3.3 總結

0xFF 參考

PyTorch分布式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[源碼解析]深度學習利器之自動微分(3) --- 示例解讀

[源碼解析]PyTorch如何實作前向傳播(1) --- 基礎類(上)

[源碼解析]PyTorch如何實作前向傳播(2) --- 基礎類(下)

[源碼解析] PyTorch如何實作前向傳播(3) --- 具體實作

[源碼解析] Pytorch 如何實作後向傳播 (1)---- 調用引擎

[源碼解析] Pytorch 如何實作後向傳播 (2)---- 引擎靜态結構

[源碼解析] Pytorch 如何實作後向傳播 (3)---- 引擎動态邏輯

[源碼解析] PyTorch 如何實作後向傳播 (4)---- 具體算法

[源碼解析] PyTorch 分布式(1)------曆史和概述

[源碼解析] PyTorch 分布式(2) ----- DataParallel(上)

[源碼解析] PyTorch 分布式(3) ----- DataParallel(下)

[源碼解析] PyTorch 分布式(4)------分布式應用基礎概念

[源碼解析] PyTorch分布式(5) ------ DistributedDataParallel 總述&如何使用

[源碼解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store

[源碼解析] PyTorch 分布式(7) ----- DistributedDataParallel 之程序組

[源碼解析] PyTorch 分布式(8) -------- DistributedDataParallel之論文篇

[源碼解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化

[源碼解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer靜态架構

[源碼解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 建構Reducer和Join操作

[源碼解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向傳播

[源碼解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向傳播

[源碼解析] PyTorch 分布式 Autograd (1) ---- 設計

[源碼解析] PyTorch 分布式 Autograd (2) ---- RPC基礎

[源碼解析] PyTorch 分布式 Autograd (3) ---- 上下文相關

為了更好的說明,本文代碼會依據具體情況來進行相應精簡。

我們回憶一下前面幾篇文章的内容。

首先,對于分布式 autograd,我們需要在前向傳播期間跟蹤所有 RPC,以確定正确執行後向傳播。為此,當執行 RPC 時候,我們把 <code>send</code>和<code>recv</code> functions 附加到autograd圖之上。

該<code>send</code>函數附加到 RPC 的發起源節點之上,其輸出邊指向 RPC 輸入張量的 autograd 函數。在向後傳播期間,<code>send</code>函數的輸入是從目标接收的,是對應<code>recv</code>函數的輸出。

該<code>recv</code>函數附加到 RPC 的接受目标節點之上,其輸入從某些運算符得到,這些運算符使用輸入張量在RPC接受目标上執行。在後向傳播期間,<code>recv</code>函數的輸出梯度将被發送到源節點之上,并且作為<code>send</code>方法的輸入。

每<code>send-recv</code>對被配置設定一個全局唯一的<code>autograd_message_id</code> 以唯一地辨別該<code>send-recv</code>對。這對于在向後傳播期間查找遠端節點上的相應函數很有用。

對于RRef,每當我們調用<code>torch.distributed.rpc.RRef.to_here()</code> 時,我們都為涉及的張量添加了一個适當的<code>send-recv</code>對。

其次,在前向傳播的具體代碼之中,我們在上下文中存儲每個 autograd 傳播的<code>send</code>和<code>recv</code>函數。這確定我們在 autograd 圖中儲存對适當節點的引用以使其保持活動狀态。除此之外,這也使得在向後傳播期間很容易查找到對應的<code>send</code>和<code>recv</code>函數。

再次,以下是 torch/csrc/distributed/rpc/message.h 之中的部分消息定義:

在前文,我們看到了 FORWARD_AUTOGRAD_REQ 在前向傳播之中如何調用,假設如下代碼:rpc.rpc_sync("worker1", torch.add, args=(t1, t2)),其調用序列是:

rpc_sync 調用 _invoke_rpc。

_invoke_rpc 調用 _invoke_rpc_builtin。

然後調用到 pyRpcBuiltin,繼而調用到 sendMessageWithAutograd。

sendMessageWithAutograd 内部會建構 FORWARD_AUTOGRAD_REQ消息,最後使用RPC 發送。

至此,關于整體流程,我們就有了幾個疑問:

在反向計算圖的起始位置,如何發起反向傳播,怎麼傳遞給反向傳播的下一個環節?

在反向傳播的内部環節,BACKWARD_AUTOGRAD_REQ 是何時調用?recv 操作是何時被調用? 在上下文中,recvAutogradFunctions_ 是在哪裡設定的?

以上兩個環節分别如何進入分布式autograd引擎?

我們接下來就圍繞這些疑問進行分析,核心就是如何進入 dist.autograd 引擎。

我們首先從計算圖來通過幾個示例來看看。

首先看看普通計算,這個是 dist.auto 官方圖例的本地版本。可以看到是由 AddBackward0,AccumulateGrad 和 MulBackward0 等組成了計算圖。

具體對應如下圖:

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

接下來看看分布式的例子,這個例子就是官方設計中圖例大緻對應的代碼,我們把 torch.mul(t3, t4) 命名為 t5,加入了 loss。

在分布式之下,t3 是異地運作。

t5 對應的是 mul,t5.grad_fn 是 &lt;MulBackward0 object at 0x7fbf18d297b8&gt;。

t3.grad_fn 是 &lt;CppFunction object at 0x7fbf18d11a20&gt;,就是說,recv 對應的就是 CppFunction 。

loss 是 tensor(5.5680, grad_fn=)。

其餘的都是 None。

我們把設計圖例再展示出來,上面示例代碼就是下圖的左側 worker 0,t3 實際就是運作在 worker 1,大家可以看到分布式上下文中的一些特點。

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

為了更好的說明,我們列印了一些log作為注釋。

列印結果是:

加上分布式相關算子之後,圖例如下:

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

我們接下來要看看如何進入dist autograd 引擎,結合我們圖例,就是:

worker 0 如何主動發起反向傳播,然後進入分布式引擎?

woker 0 在内部如何發起對 worker 1 的反向傳播請求?

worker 1 如何被動接受反向傳播消息,然後進入分布式引擎?

我們找一找如何發起反向傳播,按照從下往上的順序進行。這裡也有兩種:

一種是主動發起,比如上圖之中 worker 0 的 loss 之上主動調用backward 方法。

一種是内部隐式發起,比如上圖的 worker 0 之中的 t3 如何通過 recv 告訴 worker 1,你應該啟動反向傳播了。

我們從上往下看分布式 autograd 的 backward 如何主動調用,比如在示例之中會顯示調用。

在 <code>torch/_C/_distributed_autograd.pyi</code> 之中我們可以看到如下注釋:

是以我們去torch/csrc/distributed/autograd/init.cpp檔案中看看。

省略了部分代碼,這裡能看到生成了上下文,定義了 backward,get_gradients等等。

具體 backward 定義在 torch/csrc/distributed/autograd/autograd.cpp。

可以看到,最終會調用到 DistEngine::getInstance().execute(context_id, roots, retain_graph) 完成反向傳播。這就進入了引擎。

因為是隐式發起,是以代碼比較隐蔽,我們這次采用從下至上的方式來剝絲抽繭。我們知道,如果節點之間要求反向傳播,會發送BACKWARD_AUTOGRAD_REQ,是以我們從 BACKWARD_AUTOGRAD_REQ 開始發起尋找。

在 torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp 之中 PropagateGradientsReq::toMessageImpl 會調用到 BACKWARD_AUTOGRAD_REQ。

繼續找誰發出來的 BACKWARD_AUTOGRAD_REQ,就是誰調用到了 toMessageImpl?原來在 torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp 這裡建構了 PropagateGradientsReq,會使用 toMessage 來建構一個消息。即,RecvRpcBackward 的調用會發送 BACKWARD_AUTOGRAD_REQ。

是以我們知道,在 RecvRpcBackward 的執行時候,會發送 BACKWARD_AUTOGRAD_REQ,發送給下一個節點。具體哪裡調用 RecvRpcBackward?我們會在下一篇 DistEngine 之中介紹。

此時具體如下,對應就是 worker 0 的 t3 給 worker 1 發送 BACKWARD_AUTOGRAD_REQ 消息。

對應示例圖就是:

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

我們接下來看看接收方如何處理反向傳播,我們再次回到 worker 1,就是圖上的 send 節點如何接受反向傳播消息。

在生成 TensorPipeAgent 時候,把 RequestCallbackImpl 配置為回調函數。這是 agent 的統一響應函數。前面關于代理接收邏輯時候,我們也提到了,會進入 RequestCallbackNoPython::processRpc 函數。其中可以看到有對 BACKWARD_AUTOGRAD_REQ 的處理邏輯。

這種是 RPC 的正常流程。

在 processBackwardAutogradReq 之中會:

擷取 DistAutogradContainer。

擷取 上下文,該上下文是之前在前向傳播過程之中建立的,從前文可知,本圖例之中,worker 0 和 worker 1之中每個 autograd 傳播都共享同一個上下文 context id。

通過發送方的 context id,從上下文之中擷取到對應的 SendRpcBackward。這裡我們看到了上下文是如何使用。

使用 sendFunction 作為參數,調用 executeSendFunctionAsync 進行引擎處理。

在 worker 1 的 DistEngine::executeSendFunctionAsync 内部,會進行輾轉處理,最終發送 BACKWARD_AUTOGRAD_REQ 到其反向傳播的下遊,是以我們繼續在示例圖之上修改拓展,增加一個 BACKWARD_AUTOGRAD_REQ。

[源碼解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎

我們可以看到有兩個途徑進入 dist autograd 引擎,啟動反向傳播:

一個是示例代碼顯式主動調用 backward,進而調用到 DistEngine::getInstance().execute,就是 worker 0。

一個是被動調用 DistEngine::getInstance().executeSendFunctionAsync,就是 worker 1(當然,worker 0 的 send 也對應了一個被動調用)。

現在從上至下/自下而上兩種查找反向傳播的發起源頭,都歸結到了 DistEngine,是以我們下一篇就介紹 DistEngine。