近期看batch_normalization的代码时碰到一个tf.cond()函数,特此记录。
tf.cond()函数用于控制数据流向。
通过网上查了一些文章之后,才发现使用tf.cond() 函数是控制数据流向。也就是说在TensorFlow中,tf.cond()类似于c语言中的if…else…,但是也仅仅只是类似而已。
首先看一下官方文档:
# 用于有条件的执行函数,当pred为True时,执行true_fn函数,否则执行false_fn函数 tf.cond( pred, true_fn=None, false_fn=None, strict=False, name=None, fn1=None, fn2=None )
参数:
- pred:标量决定是否返回 true_fn 或 false_fn 结果。
- true_fn:如果 pred 为 true,则被调用。
- false_fn:如果 pred 为 false,则被调用。
- strict:启用/禁用 “严格”模式的布尔值。
- name:返回的张量的可选名称前缀。
Return:
通过调用 true_fn 或 false_fn 返回的张量。如果 callables 返回单一实例列表, 则从列表中提取元素。
需要注意的是,pred参数是tf.bool型变量,直接写“True”或者“False”是python型bool,会报错的。因此可以是很使用tf.equal(is_training,True)的操作。
再看一下官方例子:
z = tf.multiply(a, b) result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.
如果x < y,将会执行tf.add操作,不会执行tf.square操作。因为cond中至少有一个分支需要z,而tf.multiply操作总是被无条件地执行。
但是我们来看这个操作,其实是反直觉的,因为按照一般的逻辑来说,应该是用不到就不执行了,通过查询官方文档,我们看到了这么一番话:
Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics。
虽然这种行为与 TensorFlow 的数据流模型是一致的,但有时候,还是会让有些期望慵懒的用户惊讶。
真是善解人意、、、那么我们来看看例子理解一下:
import tensorflow as tf a=tf.constant(2) b=tf.constant(3) x=tf.constant(4) y=tf.constant(5) z = tf.multiply(a, b) result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) with tf.Session() as session: print(result.eval()) print(z.eval()) print(y.eval())
> 10 > 6 > 5
首先z = a * b = 2 * 3 = 6,然后在tf.cond()函数中,因为x<y(4<5)成立,所以执行lambda: tf.add(x, z),也就是result = x + z = 10,而不执行lambda: tf.square(y),但是执行了z = tf.multiply(a, b)。