天天看点

tensorflow2.0中Layer的__init__(),build(), call()函数

最近在实验中,需要用到tensorflow建立一个简单的模型,但鉴于部分要求比较苛刻,不能直接使用其内置的layer,因此需要自定义一个layer类,这便涉及到了对​

​__init__()​

​​, ​

​build()​

​​, ​

​call()​

​这三个函数的理解

先看​​官方手册​​中使用了Layer中的这三个关键函数的一个简单的实例:

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_variable("kernel",
                                    shape=[int(input_shape[-1]),
                                           self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

layer = MyDenseLayer(10)      

从直观上理解,似乎​

​__init__()​

​​和​

​build()​

​​函数都在对Layer进行初始化,都初始化了一些成员函数,而​

​call()​

​函数则是在该layer被调用时执行。

显然,这三个函数都是从​

​tf.keras.layers.Layer​

​处继承而来的,那么不妨看一下官方对这几个函数作何解释。

下图为​

​tf.keras.layers.Layer​

​的官方文档

tensorflow2.0中Layer的__init__(),build(), call()函数

简单翻译,就是说官方推荐凡是​

​tf.keras.layers.Layer​

​​的派生类都要实现​

​__init__()​

​​,​

​build()​

​​, ​

​call()​

​这三个方法

​__init__()​

​:保存成员变量的设置

​build()​

​​:在​

​call()​

​​函数第一次执行时会被调用一次,这时候可以知道输入数据的​

​shape​

​​。返回去看一看,果然是​

​__init__()​

​​函数中只初始化了输出数据的​

​shape​

​​,而输入数据的​

​shape​

​​需要在​

​build()​

​​函数中动态获取,这也解释了为什么在有​

​__init__()​

​​函数时还需要使用​

​build()​

​函数

​call()​

​​: ​

​call()​

​函数就很简单了,即当其被调用时会被执行。下面附上这几个函数的文档,就不做详细介绍了,有兴趣可以自己看看:

tensorflow2.0中Layer的__init__(),build(), call()函数
tensorflow2.0中Layer的__init__(),build(), call()函数