天天看點

tensorflow2.0中Layer的__init__(),build(), call()函數

最近在實驗中,需要用到tensorflow建立一個簡單的模型,但鑒于部分要求比較苛刻,不能直接使用其内置的layer,是以需要自定義一個layer類,這便涉及到了對​

​__init__()​

​​, ​

​build()​

​​, ​

​call()​

​這三個函數的了解

先看​​官方手冊​​中使用了Layer中的這三個關鍵函數的一個簡單的執行個體:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_variable("kernel",
                                    shape=[int(input_shape[-1]),
                                           self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

layer = MyDenseLayer(10)      

從直覺上了解,似乎​

​__init__()​

​​和​

​build()​

​​函數都在對Layer進行初始化,都初始化了一些成員函數,而​

​call()​

​函數則是在該layer被調用時執行。

顯然,這三個函數都是從​

​tf.keras.layers.Layer​

​處繼承而來的,那麼不妨看一下官方對這幾個函數作何解釋。

下圖為​

​tf.keras.layers.Layer​

​的官方文檔

tensorflow2.0中Layer的__init__(),build(), call()函數

簡單翻譯,就是說官方推薦凡是​

​tf.keras.layers.Layer​

​​的派生類都要實作​

​__init__()​

​​,​

​build()​

​​, ​

​call()​

​這三個方法

​__init__()​

​:儲存成員變量的設定

​build()​

​​:在​

​call()​

​​函數第一次執行時會被調用一次,這時候可以知道輸入資料的​

​shape​

​​。傳回去看一看,果然是​

​__init__()​

​​函數中隻初始化了輸出資料的​

​shape​

​​,而輸入資料的​

​shape​

​​需要在​

​build()​

​​函數中動态擷取,這也解釋了為什麼在有​

​__init__()​

​​函數時還需要使用​

​build()​

​函數

​call()​

​​: ​

​call()​

​函數就很簡單了,即當其被調用時會被執行。下面附上這幾個函數的文檔,就不做詳細介紹了,有興趣可以自己看看:

tensorflow2.0中Layer的__init__(),build(), call()函數
tensorflow2.0中Layer的__init__(),build(), call()函數