最近在實驗中,需要用到tensorflow建立一個簡單的模型,但鑒于部分要求比較苛刻,不能直接使用其内置的layer,是以需要自定義一個layer類,這便涉及到了對
__init__()
,
build()
,
call()
這三個函數的了解
先看官方手冊中使用了Layer中的這三個關鍵函數的一個簡單的執行個體:
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
layer = MyDenseLayer(10)
從直覺上了解,似乎
__init__()
和
build()
函數都在對Layer進行初始化,都初始化了一些成員函數,而
call()
函數則是在該layer被調用時執行。
顯然,這三個函數都是從
tf.keras.layers.Layer
處繼承而來的,那麼不妨看一下官方對這幾個函數作何解釋。
下圖為
tf.keras.layers.Layer
的官方文檔
簡單翻譯,就是說官方推薦凡是
tf.keras.layers.Layer
的派生類都要實作
__init__()
,
build()
,
call()
這三個方法
__init__()
:儲存成員變量的設定
build()
:在
call()
函數第一次執行時會被調用一次,這時候可以知道輸入資料的
shape
。傳回去看一看,果然是
__init__()
函數中隻初始化了輸出資料的
shape
,而輸入資料的
shape
需要在
build()
函數中動态擷取,這也解釋了為什麼在有
__init__()
函數時還需要使用
build()
函數
call()
:
call()
函數就很簡單了,即當其被調用時會被執行。下面附上這幾個函數的文檔,就不做詳細介紹了,有興趣可以自己看看: