天天看点

tf.cond()函数解析(最清晰的解释)

近期看batch_normalization的代码时碰到一个tf.cond()函数,特此记录。

tf.cond()函数用于控制数据流向。

通过网上查了一些文章之后,才发现使用tf.cond() 函数是控制数据流向。也就是说在TensorFlow中,tf.cond()类似于c语言中的if…else…,但是也仅仅只是类似而已。

首先看一下官方文档:

# 用于有条件的执行函数,当pred为True时,执行true_fn函数,否则执行false_fn函数     tf.cond(     pred,     true_fn=None,     false_fn=None,     strict=False,     name=None,     fn1=None,     fn2=None     )      

参数:

  • pred:标量决定是否返回 true_fn 或 false_fn 结果。
  • true_fn:如果 pred 为 true,则被调用。
  • false_fn:如果 pred 为 false,则被调用。
  • strict:启用/禁用 “严格”模式的布尔值。
  • name:返回的张量的可选名称前缀。

Return:

通过调用 true_fn 或 false_fn 返回的张量。如果 callables 返回单一实例列表, 则从列表中提取元素。

需要注意的是,pred参数是tf.bool型变量,直接写“True”或者“False”是python型bool,会报错的。因此可以是很使用tf.equal(is_training,True)的操作。

再看一下官方例子:

z = tf.multiply(a, b)     result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))      

If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally.

如果x < y,将会执行tf.add操作,不会执行tf.square操作。因为cond中至少有一个分支需要z,而tf.multiply操作总是被无条件地执行。

但是我们来看这个操作,其实是反直觉的,因为按照一般的逻辑来说,应该是用不到就不执行了,通过查询官方文档,我们看到了这么一番话:

Although this behavior is consistent with the dataflow model of TensorFlow,it has occasionally surprised some users who expected a lazier semantics。

虽然这种行为与 TensorFlow 的数据流模型是一致的,但有时候,还是会让有些期望慵懒的用户惊讶。

真是善解人意、、、那么我们来看看例子理解一下:

import tensorflow as tf     a=tf.constant(2)         b=tf.constant(3)         x=tf.constant(4)         y=tf.constant(5)         z = tf.multiply(a, b)         result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))         with tf.Session() as session:         print(result.eval())     print(z.eval())     print(y.eval())      
> 10     > 6     > 5      

首先z = a * b = 2 * 3 = 6,然后在tf.cond()函数中,因为x<y(4<5)成立,所以执行lambda: tf.add(x, z),也就是result = x + z = 10,而不执行lambda: tf.square(y),但是执行了z = tf.multiply(a, b)。