1 概述
TensorFlow後端分為四層,運作時層、計算層、通信層、裝置層。運作時作為第一層,實作了session管理、graph管理等很多重要的邏輯,是十分關鍵的一層。根據任務分布的不同,運作時又分為本地運作時和分布式運作時。本地運作時,所有任務運作于本地同一程序内。而分布式運作時,則允許任務運作在不同機器上。
Tensorflow的運作,通過session搭建了前後端溝通的橋梁,前端幾乎所有操作都是通過session進行。session的生命周期由建立、運作、關閉、銷毀組成,前文已經詳細講述過。可以将session看做TensorFlow運作的載體。而TensorFlow運作的核心對象,則是計算圖Graph。它由計算算子和計算資料兩部分構成,可以完整描述整個計算内容。Graph的生命周期包括建構和傳遞、剪枝、分裂、執行等步驟,本文會詳細講解。了解TensorFlow的運作時,重點就是了解會話session和計算圖Graph。
本地運作時,client master和worker都在本地機器的同一程序内,均通過DirectSession類來描述。由于在同一程序内,三者間可以共享記憶體,通過DirectSession的相關函數實作調用。
client前端直接面向使用者,負責session的建立,計算圖Graph的構造。并通過session.run()将Graph序列化後傳遞給master。master收到後,先反序列化得到Graph,然後根據反向依賴關系,得到幾個最小依賴子圖,這一步稱為剪枝。之後master根據可運作的裝置情況,将子圖分裂到不同裝置上,進而可以并發執行,這一步稱為分裂。最後,由每個裝置上的worker并行執行分裂後的子圖,得到計算結果後傳回。
2 Graph建構和傳遞
session.run()開啟了後端Graph的建構和傳遞。在前文session生命周期的講解中,session.run()時會先調用_extend_graph()将要運作的Operation添加到Graph中,然後再啟動運作過程。extend_graph()會先将graph序列化,得到graph_def,然後調用後端的TF_ExtendGraph()方法。下面我們從 http://c_api.cc 中的TF_ExtendGraph()看起。
// 增加節點到graph中,proto為序列化後的graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
size_t proto_len, TF_Status* status) {
GraphDef g;
// 先将proto轉換為GrapDef。graphDef是圖的序列化表示,反序列化在後面。
if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
// 再調用session的extend方法。根據建立的不同session類型,多态調用不同方法。
status->status = s->session->Extend(g);
}
後端系統根據生成的Session類型,多态的調用Extend方法。如果是本地session,則調用DirectSession的Extend()方法。下面看DirectSession的Extend()方法。
Status DirectSession::Extend(const GraphDef& graph) {
// 保證線程安全,然後調用ExtendLocked()
mutex_lock l(graph_def_lock_);
return ExtendLocked(graph);
}
// 主要任務就是建立GraphExecutionState對象。
Status DirectSession::ExtendLocked(const GraphDef& graph) {
bool already_initialized;
if (already_initialized) {
TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
// 建立GraphExecutionState
std::unique_ptr<GraphExecutionState> state;
TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
execution_state_.swap(state);
}
return Status::OK();
}
最終建立了GraphExecutionState對象。它主要工作有
- 負責将GraphDef反序列化為graph,進而構造出graph。在初始化方法InitBaseGraph()中
- 執行部分op編排工作,在初始化方法InitBaseGraph()中
Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
const GraphDef* graph_def = &original_graph_def_;
// graphDef反序列化得到graph
std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));
// 恢複有狀态的節點
RestoreStatefulNodes(new_graph.get());
// 構造優化器的選項 optimization_options
GraphOptimizationPassOptions optimization_options;
optimization_options.session_options = session_options_;
optimization_options.graph = &new_graph;
optimization_options.flib_def = flib_def_.get();
optimization_options.device_set = device_set_;
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
// plaer執行op編排
Placer placer(new_graph.get(), device_set_, session_options_);
TF_RETURN_IF_ERROR(placer.Run());
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_PLACEMENT, optimization_options));
// 報春狀态節點
SaveStatefulNodes(new_graph.get());
graph_ = new_graph.release();
return Status::OK();
}
構造Graph:反序列化GraphDef為Graph
由于client傳遞給master的是序列化後的計算圖,是以master需要先反序列化。通過ConvertGraphDefToGraph實作。代碼在 http://graph_constructor.cc 中,如下
Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g) {
ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
/*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
/*missing_unused_input_map_keys=*/nullptr);
}
編排OP
Operation編排的目的是,将op以最高效的方式,放在合适的硬體裝置上,進而最大限度的發揮硬體能力。通過Placer的run()方法進行,算法很複雜,在 http://placer.cc 中,我也看得不大懂,就不展開了。
3 Graph剪枝
反序列化建構好Graph,并進行了Operation編排後,master就開始對Graph剪枝了。剪枝就是根據Graph的輸入輸出清單,反向周遊全圖,找到幾個最小依賴的子圖,進而友善并行計算。
Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
std::unique_ptr<ClientGraph>* out) {
std::unique_ptr<Graph> ng;
Status s = OptimizeGraph(options, &ng);
if (!s.ok()) {
// 1 複制一份原始的Graph
ng.reset(new Graph(flib_def_.get()));
CopyGraph(*graph_, ng.get());
}
// 2 剪枝,根據輸入輸出feed fetch,對graph進行增加節點或删除節點等操作。通過RewriteGraphForExecution()方法
subgraph::RewriteGraphMetadata rewrite_metadata;
if (session_options_ == nullptr ||
!session_options_->config.graph_options().place_pruned_graph()) {
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
ng.get(), options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, &rewrite_metadata));
}
// 3 處理優化選項optimization_options
GraphOptimizationPassOptions optimization_options;
optimization_options.session_options = session_options_;
optimization_options.graph = &ng;
optimization_options.flib_def = flib.get();
optimization_options.device_set = device_set_;
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
// 4 複制一份ClientGraph
std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(flib), rewrite_metadata.feed_types,
rewrite_metadata.fetch_types));
CopyGraph(*ng, &dense_copy->graph);
*out = std::move(dense_copy);
return Status::OK();
}
剪枝的關鍵在RewriteGraphForExecution()方法中,在 http://subgraph.cc 檔案中。
Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata) {
std::unordered_set<string> endpoints;
// 1 建構節點的name_index,進而快速索引節點。為FeedInputs,FetchOutputs等步驟所使用
NameIndex name_index;
name_index.reserve(g->num_nodes());
for (Node* n : g->nodes()) {
name_index[n->name()] = n;
}
// 2 FeedInputs,添加輸入節點
if (!fed_outputs.empty()) {
FeedInputs(g, device_info, fed_outputs, use_function_convention, &name_index, &out_metadata->feed_types);
}
// 3 FetchOutputs,添加輸出節點
std::vector<Node*> fetch_nodes;
if (!fetch_outputs.empty()) {
FetchOutputs(g, device_info, fetch_outputs, use_function_convention, &name_index, &fetch_nodes, &out_metadata->fetch_types);
}
// 4 剪枝,形成若幹最小依賴子圖
if (!fetch_nodes.empty() || !target_node_names.empty()) {
PruneForTargets(g, name_index, fetch_nodes, target_node_names);
}
return Status::OK();
}
主要有4步
- 建構節點的name_index,進而快速索引節點。為FeedInputs,FetchOutputs等步驟所使用
- FeedInputs,添加輸入節點。輸入節點的資料來源于session.run()時的feed清單。
- FetchOutputs,添加輸出節點。輸出節點在session.run()時通過fetches所給出
- 剪枝PruneForTargets,形成若幹最小依賴子圖。這是剪枝算法最關鍵的一步。
PruneForTargets()從輸出節點反向搜尋,按照BFS廣度優先算法,找到若幹個最小依賴子圖。
static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
const std::vector<Node*>& fetch_nodes,
const gtl::ArraySlice<string>& target_nodes) {
string not_found;
std::unordered_set<const Node*> targets;
// 1 AddNodeToTargets添加節點到targets中,從輸出節點按照BFS反向周遊。
for (Node* n : fetch_nodes) {
AddNodeToTargets(n->name(), name_index, &targets);
}
// 2 剪枝,得到多個最小依賴子圖子圖
PruneForReverseReachability(g, targets);
// 修正Source和Sink節點的依賴邊,将沒有輸出邊的節點連接配接到sink node上
FixupSourceAndSinkEdges(g);
return Status::OK();
}
主要有3步
- AddNodeToTargets,從輸出節點按照BFS反向周遊圖的節點,添加到targets中。
- PruneForReverseReachability,剪枝,得到多個最小依賴子圖子圖
- FixupSourceAndSinkEdges,修正Source和Sink節點的依賴邊,将沒有輸出邊的節點連接配接到sink node上
PruneForReverseReachability()在 http://algorithm.cc 檔案中,算法就不分析了,總體是按照BFS廣度優先算法搜尋的。
bool PruneForReverseReachability(Graph* g,
std::unordered_set<const Node*> visited) {
// 按照BFS廣度優先算法,從輸出節點開始,反向搜尋節點的依賴關系
std::deque<const Node*> queue;
for (const Node* n : visited) {
queue.push_back(n);
}
while (!queue.empty()) {
const Node* n = queue.front();
queue.pop_front();
for (const Node* in : n->in_nodes()) {
if (visited.insert(in).second) {
queue.push_back(in);
}
}
}
// 删除不在"visited"清單中的節點,說明最小依賴子圖不依賴此節點
std::vector<Node*> all_nodes;
all_nodes.reserve(g->num_nodes());
for (Node* n : g->nodes()) {
all_nodes.push_back(n);
}
bool any_removed = false;
for (Node* n : all_nodes) {
if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
g->RemoveNode(n);
any_removed = true;
}
}
return any_removed;
}
4 Graph分裂
剪枝完成後,master即得到了最小依賴子圖ClientGraph。然後根據本地機器的硬體裝置,以及op所指定的運作裝置等關系,将圖分裂為多個Partition Graph,傳遞到相關裝置的worker上,進而進行并行運算。這就是Graph的分裂。
Graph分裂的算法在 http://graph_partition.cc 的Partition()方法中。算法比較複雜,我們就不分析了。圖分裂有兩種
- splitbydevice 按裝置分裂,也就是将Graph分裂到本地各CPU GPU上。本地運作時隻使用按裝置分裂。
static string SplitByDevice(const Node* node) {
return node->assigned_device_name();
}
- splitByWorker 按worker分裂, 也就是将Graph分裂到各分布式任務上,常用于分布式運作時。分布式運作時,圖會經曆兩次分裂。先splitByWorker分裂到各分布式任務上,一般是各分布式機器。然後splitbydevice二次分裂到分布式機器的CPU GPU等裝置上。
static string SplitByWorker(const Node* node) {
string task;
string device;
DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task, &device);
return task;
}
5 Graph執行
Graph經過master剪枝和分裂後,就可以在本地的各CPU GPU裝置上執行了。這個過程的管理者叫worker。一般一個worker對應一個分裂後的子圖partitionGraph。每個worker啟動一個執行器Executor,入度為0的節點資料依賴已經ready了,故可以并行執行。等所有Executor執行完畢後,通知執行完畢。
各CPU GPU裝置間可能需要資料通信,通過建立send/recv節點來解決。資料發送方建立send節點,将資料放在send節點内,不阻塞。資料接收方建立recv節點,從recv節點中取出資料,recv節點中如果沒有資料則阻塞。這又是一個典型的生産者-消費者關系。
Graph執行的代碼邏輯在 http://direct_session.cc 檔案的DirectSession::Run()方法中。代碼邏輯很長,我們抽取其中的關鍵部分。
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
// 1 将輸入tensor的name取出,組成一個清單,友善之後快速索引輸入tensor
std::vector<string> input_tensor_names;
input_tensor_names.reserve(inputs.size());
for (const auto& it : inputs) {
input_tensor_names.push_back(it.first);
}
// 2 傳遞輸入資料給executor,通過FunctionCallFrame方式。
// 2.1 建立FunctionCallFrame,用來輸入資料給executor,并從executor中取出資料。
FunctionCallFrame call_frame(executors_and_keys->input_types,
executors_and_keys->output_types);
// 2.2 構造輸入資料feed_args
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
for (const auto& it : inputs) {
if (it.second.dtype() == DT_RESOURCE) {
Tensor tensor_from_handle;
ResourceHandleToInputTensor(it.second, &tensor_from_handle);
feed_args[executors_and_keys->input_name_to_index[it.first]] = tensor_from_handle;
} else {
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
}
}
// 2.3 将feed_args輸入資料設定到Arg節點上
const Status s = call_frame.SetArgs(feed_args);
// 3 開始執行executor
// 3.1 建立run_state, 和IntraProcessRendezvous
RunState run_state(args.step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
CancellationManager step_cancellation_manager;
args.call_frame = &call_frame;
// 3.2 建立ExecutorBarrier,它是一個執行完成的計數器。同時注冊執行完成的監聽事件executors_done.Notify()
const size_t num_executors = executors_and_keys->items.size();
ExecutorBarrier* barrier = new ExecutorBarrier(
num_executors, run_state.rendez, [&run_state](const Status& ret) {
{
mutex_lock l(run_state.mu_);
run_state.status.Update(ret);
}
// 所有線程池計算完畢後,會觸發Notify,發送消息。
run_state.executors_done.Notify();
});
args.rendezvous = run_state.rendez;
args.cancellation_manager = &step_cancellation_manager;
args.session_state = &session_state_;
args.tensor_store = &run_state.tensor_store;
args.step_container = &run_state.step_container;
args.sync_on_finish = sync_on_finish_;
// 3.3 建立executor的運作器Runner
Executor::Args::Runner default_runner = [this,
pool](Executor::Args::Closure c) {
SchedClosure(pool, std::move(c));
};
// 3.4 依次啟動所有executor,開始運作
for (const auto& item : executors_and_keys->items) {
item.executor->RunAsync(args, barrier->Get());
}
// 3.5 阻塞,收到所有executor執行完畢的通知
WaitForNotification(&run_state, &step_cancellation_manager, operation_timeout_in_ms_);
// 4 接收執行器執行完畢的輸出值
if (outputs) {
// 4.1 從RetVal節點中得到輸出值sorted_outputs
std::vector<Tensor> sorted_outputs;
const Status s = call_frame.ConsumeRetvals(&sorted_outputs);
// 4.2 處理原始輸出sorted_outputs,儲存到最終的輸出outputs中
outputs->clear();
outputs->reserve(sorted_outputs.size());
for (int i = 0; i < output_names.size(); ++i) {
const string& output_name = output_names[i];
if (first_indices.empty() || first_indices[i] == i) {
outputs->emplace_back(
std::move(sorted_outputs[executors_and_keys->output_name_to_index[output_name]]));
} else {
outputs->push_back((*outputs)[first_indices[i]]);
}
}
}
// 5 儲存輸出的tensor
run_state.tensor_store.SaveTensors(output_names, &session_state_));
return Status::OK();
}
主要步驟如下
- 将輸入tensor的name取出,組成一個清單,友善之後快速索引輸入tensor
- 傳遞輸入資料給executor,通過FunctionCallFrame方式。本地運作時因為在同一個程序中,我們采用FunctionCallFrame函數調用的方式來實作資料傳遞。将輸入資料傳遞給Arg節點,從RetVal節點中取出資料。
- 開始執行executor,并注冊監聽器。所有executor執行完畢後,會觸發executors_done.Notify()事件。然後目前線程wait阻塞,等待收到執行完畢的消息。
- 收到執行完畢的消息後,從RetVal節點中取出輸出值,經過簡單處理後,就可以最終輸出了
- 儲存輸出的tensor,友善以後使用。
6 總結
本文主要講解了TensorFlow的本地運作時,牢牢抓住session和graph兩個對象即可。Session的生命周期前文講解過,本文主要講解了Graph的生命周期,包括建構與傳遞,剪枝,分裂和執行。Graph是TensorFlow的核心對象,很多問題都是圍繞它來進行的,了解它有一定難度,但十分關鍵。文章中可能有一些了解不正确的地方,希望小夥伴們不吝賜教。
系列文章,歡迎閱讀
謝楊易:Tensorflow源碼解析1 -- 核心架構和源碼結構zhuanlan.zhihu.com

謝楊易:Tensorflow源碼解析2 -- 前後端連接配接的橋梁 - Sessionzhuanlan.zhihu.com
謝楊易:Tensorflow源碼解析3 -- TensorFlow核心對象 - Graphzhuanlan.zhihu.com
謝楊易:Tensorflow源碼解析4 -- 圖的節點 - Operationzhuanlan.zhihu.com
謝楊易:Tensorflow源碼解析5 -- 圖的邊 - Tensorzhuanlan.zhihu.com
謝楊易:Tensorflow源碼解析6 -- TensorFlow本地運作時zhuanlan.zhihu.com
謝楊易:Tensorflow源碼解析7 -- TensorFlow分布式運作時zhuanlan.zhihu.com