TensorFlow c++ SessionFactory注冊與No session factory registered錯誤
背景
近期我們在伺服器上使用TensorFlow來進行推理,作為雲推理服務的基礎。具體如何使用tensorflow c++庫來進行推斷可以參考之前的記錄《從0開始使用tensorflow的c++庫進行模型推斷》。
在經過一些demo的驗證之後,我們開始在項目中正式使用TensorFlow的c++庫,簡單描述我們的用法:tensorflow作為最底層,上面有一層對推理架構的封裝靜态庫libnn.a,然後是main函數調用這個靜态庫和tf的動态庫完成推理。在這個過程中執行NewSession()時遇到報錯:
No session factory registered for the given session options:{target: ""} Registered factories are {}
。
解決辦法
參考網上的各種辦法,均未解決。
目前我的解決辦法是,在main函數的源檔案中顯示增加對TensorFlow相關頭檔案的引用,随便include以下頭檔案之一,然後再編譯可執行檔案,即可正确加載并運作,當然其他類似檔案也可以,不過我隻驗證了這三個:
“tensorflow/core/public/session.h”
“tensorflow/client/client_session.h”
“tensorflow/core/framework/tensor.h”
問題分析
其實這個才是本文的主要部分。
從字面意思來看,這個報錯是因為Session factory注冊失敗。
好!那我們就從什麼是Session factory開始吧!
1. Session Factory
在tensorflow/core/common_runtime/session_factory.h中,我們找到了這個類:
class Session;
struct SessionOptions;
class SessionFactory {
public:
virtual Status NewSession(const SessionOptions& options,
Session** out_session) = 0;
virtual bool AcceptsOptions(const SessionOptions& options) = 0;
virtual Status Reset(const SessionOptions& options,
const std::vector<string>& containers) {
return errors::Unimplemented("Reset()");
}
static void Register(const string& runtime_type, SessionFactory* factory);
static Status GetFactory(const SessionOptions& options,
SessionFactory** out_factory);
}
首先在這個頭檔案開始,聲明了兩個類Session 以及 SessionOptions,接下來再看成員函數名
NewSession, AcceptsOptions, Reset, Register, GetFactory
,我們大概可以猜到這個Factory是用來完成SessionFactory的注冊,以及建立新的Session,過程中需要使用SessionOptions作為配置。實際上也的确如此。
如果我們在TensorFlow的源碼中搜一下
No session factory registered...
報錯,就可以發現這個錯誤的直接來源就是SessionFactory::GetFactory異常:
Status SessionFactory::GetFactory(const SessionOptions& options,
SessionFactory** out_factory) {
......
} else {
return errors::NotFound(
"No session factory registered for the given session options: {",
SessionOptionsToString(options), "} ",
RegisteredFactoriesErrorMessageLocked());
}
到此我們知道了報錯的具體位置以及報錯的
表層原因
:
- main函數調用libnn.a中的init函數
- init函數調用Status NewSession(const SessionOptions& options, Session** out_session)函數
- 調用SessionFactory::GetFactory函數
- 報錯
2. SessionFactory的注冊
前面我們已經了解到報錯的表層原因是因為執行GetFactory傳回異常,接下來我們仔細分析一下為什麼會出現這個異常,我們還是從SessionFactory::GetFactory入手:
Status SessionFactory::GetFactory(const SessionOptions& options,
SessionFactory** out_factory) {
...
std::vector<std::pair<string, SessionFactory*>> candidate_factories;
for (const auto& session_factory : *session_factories()) {
if (session_factory.second->AcceptsOptions(options)) {
VLOG(2) << "SessionFactory type " << session_factory.first
<< " accepts target: " << options.target;
candidate_factories.push_back(session_factory);
} else {
VLOG(2) << "SessionFactory type " << session_factory.first
<< " does not accept target: " << options.target;
}
}
if (candidate_factories.size() == 1) {
*out_factory = candidate_factories[0].second;
return Status::OK();
} else if (candidate_factories.size() > 1) {
...
else {
return errors::NotFound(
"No session factory registered for the given session options: {",
SessionOptionsToString(options), "} ",
RegisteredFactoriesErrorMessageLocked());
}
}
報錯異常的直接原因,是因為candidate_factories<1,而這個候選工廠數量,是前面for循環得到的滿足條件的factory。
滿足什麼條件呢?又是從什麼集合裡面去篩選呢?
我們一個個來分析,首先需要滿足的條件是我們在外面定義的const SessionOptions& options,要建立Session,就必須傳入一個SessionOptions對象,一般來說,如果不做一些細節調優,我們會在NewSession中傳入一個SessionOptions的預設構造對象SessionOptions(),這個預設構造的對象隻包含基本的環境變量,對所有factories來說都可以通過AcceptsOptions的判斷的。
從什麼集合裡面去篩選候選Factory呢?在代碼裡面,就是for循環中的*session_factories(),具體來說就是:
typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
static SessionFactories* factories = new SessionFactories;
return factories;
}
當調用這個函數的時候,傳回factories對象,即一個unordered_map,key是string,value是SessionFactory指針,即session_factory.second。
問題又來了,這個unordered_map,又是什麼時候被指派的呢?
在源檔案中搜一下
insert
,就發現了這個函數:
void SessionFactory::Register(const string& runtime_type,
SessionFactory* factory) {
mutex_lock l(*get_session_factory_lock());
if (!session_factories()->insert({runtime_type, factory}).second) {
LOG(ERROR) << "Two session factories are being registered "
<< "under" << runtime_type;
}
}
每次調用Register,都會将一組key為runtime_type,value為factory的鍵值對放入unordered_map中,由于是map,會自動進行判重,也就是說每個runtime_type對應一個factory。
接下來的問題是,什麼時候調用這個Register?
這個不太好找,涉及到TensorFlow的架構設計了,簡單來說就是TensorFlow有兩種runtime_type:direct和grpc,分别在
tensorflow/core/common_runtime/direct_session.cc
和
tensorflow/core/common_runtime/grpc_session.cc
中實作。
具體來說,是通過定義一個注冊機DirectSessionRegistrar類,然後定義全局靜态變量registrar,在registrar的初始化中,調用SessionFactory::Register方法,完成對應runtime_type的注冊:
class DirectSessionRegistrar {
public:
DirectSessionRegistrar() {
SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
}
};
static DirectSessionRegistrar registrar;
class GrpcSessionRegistrar {
public:
GrpcSessionRegistrar() {
SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
}
};
static GrpcSessionRegistrar registrar;
對Session的架構設計有更多的興趣的,可以參考『深度長文』Tensorflow代碼解析(五),這裡就不再贅述了。
3. 動态庫中全局靜态變量的初始化
在上面我們已經分析清楚,這個報錯的根本原因是TensorFlow動态庫中的這兩個源檔案中的全局靜态變量registrar沒有初始化,導緻後面在建立Session的時候GetFactory失敗。
有一篇文章"Integrating TensorFlow libraries"也分析是這個原因。因為TF元件化的設計思路,一個很小的core+注冊的方式來完成包括SessionFactory、Op_Kernel、Op等的加載,在編譯的時候,如果這部分代碼被編譯器忽略了,就會導緻部分代碼注冊失敗,進而導緻報錯。
為什麼不初始化呢?
動态庫的加載分為顯示加載和隐式加載,我們後面驗證了顯示加載,即顯示通過dlopen來調用這些動态庫,是可以正常運作的,說明這種情況下,所有的注冊都完成了。問題出在隐式加載上。
為什麼動态庫隐式加載的情況下一些全局靜态變量沒有初始化呢?
這就是前面說過的解決辦法了,隐式加載動态庫,标準用法是在可執行檔案的代碼中顯示include動态庫export的函數的頭檔案,然後編譯時加上對動态庫的連結,這樣系統會自動在運作前加載動态庫。我們雖然在靜态庫中include了TF動态庫的頭檔案,但是并未在可執行程式中引用,是以會導緻TF動态庫的全局靜态變量初始化失敗!
總結
- 使用隐式調用時,則調用方必須要加上動态庫中的頭檔案,g++編譯時還需要要用參數-I指明包含的頭檔案的位置
- 進一步通過SessionFactory學習了TensorFlow的子產品化設計
- 深入Debug是一件很好玩的事情
- 千裡之堤毀于蟻穴,一定要注意基礎中的細節
參考
- linux下動态連結庫(.so)的顯式調用和隐式調用
- C語言中的 static變量(全局和局部)、static函數總結
- Integrating TensorFlow libraries
- 深度長文 Tensorflow代碼解析(五)
- TensorFlow源碼