天天看點

TensorFlow c++ SessionFactory注冊與No session factory registered錯誤

TensorFlow c++ SessionFactory注冊與No session factory registered錯誤

背景

近期我們在伺服器上使用TensorFlow來進行推理,作為雲推理服務的基礎。具體如何使用tensorflow c++庫來進行推斷可以參考之前的記錄《從0開始使用tensorflow的c++庫進行模型推斷》。

在經過一些demo的驗證之後,我們開始在項目中正式使用TensorFlow的c++庫,簡單描述我們的用法:tensorflow作為最底層,上面有一層對推理架構的封裝靜态庫libnn.a,然後是main函數調用這個靜态庫和tf的動态庫完成推理。在這個過程中執行NewSession()時遇到報錯:

No session factory registered for the given session options:{target: ""} Registered factories are {}

解決辦法

參考網上的各種辦法,均未解決。

目前我的解決辦法是,在main函數的源檔案中顯示增加對TensorFlow相關頭檔案的引用,随便include以下頭檔案之一,然後再編譯可執行檔案,即可正确加載并運作,當然其他類似檔案也可以,不過我隻驗證了這三個:

“tensorflow/core/public/session.h”

“tensorflow/client/client_session.h”

“tensorflow/core/framework/tensor.h”

問題分析

其實這個才是本文的主要部分。

從字面意思來看,這個報錯是因為Session factory注冊失敗。

好!那我們就從什麼是Session factory開始吧!

1. Session Factory

在tensorflow/core/common_runtime/session_factory.h中,我們找到了這個類:

class Session;
struct SessionOptions;
class SessionFactory {
public:
	virtual Status NewSession(const SessionOptions& options,
	                          Session** out_session) = 0;
	virtual bool AcceptsOptions(const SessionOptions& options) = 0;
	virtual Status Reset(const SessionOptions& options,
                       const std::vector<string>& containers) {
	    return errors::Unimplemented("Reset()");
	}
	static void Register(const string& runtime_type, SessionFactory* factory);
	static Status GetFactory(const SessionOptions& options,
                           SessionFactory** out_factory);
}                          
           

首先在這個頭檔案開始,聲明了兩個類Session 以及 SessionOptions,接下來再看成員函數名

NewSession, AcceptsOptions, Reset, Register, GetFactory

,我們大概可以猜到這個Factory是用來完成SessionFactory的注冊,以及建立新的Session,過程中需要使用SessionOptions作為配置。實際上也的确如此。

如果我們在TensorFlow的源碼中搜一下

No session factory registered...

報錯,就可以發現這個錯誤的直接來源就是SessionFactory::GetFactory異常:

Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
	......
	} else {
    return errors::NotFound(
        "No session factory registered for the given session options: {",
        SessionOptionsToString(options), "} ",
        RegisteredFactoriesErrorMessageLocked());
  }

           

到此我們知道了報錯的具體位置以及報錯的

表層原因

  1. main函數調用libnn.a中的init函數
  2. init函數調用Status NewSession(const SessionOptions& options, Session** out_session)函數
  3. 調用SessionFactory::GetFactory函數
  4. 報錯

2. SessionFactory的注冊

前面我們已經了解到報錯的表層原因是因為執行GetFactory傳回異常,接下來我們仔細分析一下為什麼會出現這個異常,我們還是從SessionFactory::GetFactory入手:

Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
	...
	std::vector<std::pair<string, SessionFactory*>> candidate_factories;
	for (const auto& session_factory : *session_factories()) {
	 if (session_factory.second->AcceptsOptions(options)) {
	   VLOG(2) << "SessionFactory type " << session_factory.first
	           << " accepts target: " << options.target;
	   candidate_factories.push_back(session_factory);
	 } else {
	   VLOG(2) << "SessionFactory type " << session_factory.first
	           << " does not accept target: " << options.target;
	 }
	}
	
	if (candidate_factories.size() == 1) {
	    *out_factory = candidate_factories[0].second;
	    return Status::OK();
	} else if (candidate_factories.size() > 1) {
		 ...
	else {
	    return errors::NotFound(
	        "No session factory registered for the given session options: {",
	        SessionOptionsToString(options), "} ",
	        RegisteredFactoriesErrorMessageLocked());
	}
}
           

報錯異常的直接原因,是因為candidate_factories<1,而這個候選工廠數量,是前面for循環得到的滿足條件的factory。

滿足什麼條件呢?又是從什麼集合裡面去篩選呢?

我們一個個來分析,首先需要滿足的條件是我們在外面定義的const SessionOptions& options,要建立Session,就必須傳入一個SessionOptions對象,一般來說,如果不做一些細節調優,我們會在NewSession中傳入一個SessionOptions的預設構造對象SessionOptions(),這個預設構造的對象隻包含基本的環境變量,對所有factories來說都可以通過AcceptsOptions的判斷的。

從什麼集合裡面去篩選候選Factory呢?在代碼裡面,就是for循環中的*session_factories(),具體來說就是:

typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
	static SessionFactories* factories = new SessionFactories;
	return factories;
}
           

當調用這個函數的時候,傳回factories對象,即一個unordered_map,key是string,value是SessionFactory指針,即session_factory.second。

問題又來了,這個unordered_map,又是什麼時候被指派的呢?

在源檔案中搜一下

insert

,就發現了這個函數:

void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  if (!session_factories()->insert({runtime_type, factory}).second) {
    LOG(ERROR) << "Two session factories are being registered "
               << "under" << runtime_type;
  }
}
           

每次調用Register,都會将一組key為runtime_type,value為factory的鍵值對放入unordered_map中,由于是map,會自動進行判重,也就是說每個runtime_type對應一個factory。

接下來的問題是,什麼時候調用這個Register?

這個不太好找,涉及到TensorFlow的架構設計了,簡單來說就是TensorFlow有兩種runtime_type:direct和grpc,分别在

tensorflow/core/common_runtime/direct_session.cc

tensorflow/core/common_runtime/grpc_session.cc

中實作。

具體來說,是通過定義一個注冊機DirectSessionRegistrar類,然後定義全局靜态變量registrar,在registrar的初始化中,調用SessionFactory::Register方法,完成對應runtime_type的注冊:

class DirectSessionRegistrar {
 public:
  DirectSessionRegistrar() {
    SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
  }
};
static DirectSessionRegistrar registrar;
           
class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};
static GrpcSessionRegistrar registrar;
           

對Session的架構設計有更多的興趣的,可以參考『深度長文』Tensorflow代碼解析(五),這裡就不再贅述了。

3. 動态庫中全局靜态變量的初始化

在上面我們已經分析清楚,這個報錯的根本原因是TensorFlow動态庫中的這兩個源檔案中的全局靜态變量registrar沒有初始化,導緻後面在建立Session的時候GetFactory失敗。

有一篇文章"Integrating TensorFlow libraries"也分析是這個原因。因為TF元件化的設計思路,一個很小的core+注冊的方式來完成包括SessionFactory、Op_Kernel、Op等的加載,在編譯的時候,如果這部分代碼被編譯器忽略了,就會導緻部分代碼注冊失敗,進而導緻報錯。

為什麼不初始化呢?

動态庫的加載分為顯示加載和隐式加載,我們後面驗證了顯示加載,即顯示通過dlopen來調用這些動态庫,是可以正常運作的,說明這種情況下,所有的注冊都完成了。問題出在隐式加載上。

為什麼動态庫隐式加載的情況下一些全局靜态變量沒有初始化呢?

這就是前面說過的解決辦法了,隐式加載動态庫,标準用法是在可執行檔案的代碼中顯示include動态庫export的函數的頭檔案,然後編譯時加上對動态庫的連結,這樣系統會自動在運作前加載動态庫。我們雖然在靜态庫中include了TF動态庫的頭檔案,但是并未在可執行程式中引用,是以會導緻TF動态庫的全局靜态變量初始化失敗!

總結

  1. 使用隐式調用時,則調用方必須要加上動态庫中的頭檔案,g++編譯時還需要要用參數-I指明包含的頭檔案的位置
  2. 進一步通過SessionFactory學習了TensorFlow的子產品化設計
  3. 深入Debug是一件很好玩的事情
  4. 千裡之堤毀于蟻穴,一定要注意基礎中的細節

參考

  1. linux下動态連結庫(.so)的顯式調用和隐式調用
  2. C語言中的 static變量(全局和局部)、static函數總結
  3. Integrating TensorFlow libraries
  4. 深度長文 Tensorflow代碼解析(五)
  5. TensorFlow源碼

繼續閱讀