尝试在张量流中使用if map_fn

时间:2018-04-04 10:54:09

标签: python tensorflow

我尝试使用tensorflow中的map_fn将变换应用于列向量,但它无效。

对于以下列向量:

elems = np.array([[1.0], [2.0], [3.0]])

当我这样做时:

tf_m = tf.map_fn(lambda x: x + 1.0, elems)
with tf.Session() as sess:
    res = sess.run(tf_m)
    print(str(res))

我得到了我期望的结果,即这个列向量:

[[2.]
 [3.]
 [4.]]

然而,当我这样做时:

tf_m2 = tf.map_fn(lambda x: x+1 if x % 2 > 0 else x, elems)
with tf.Session() as sess:
    res = sess.run(tf_m2)
    print(str(res))

代码失败,出现以下异常:

  

TypeError:不允许使用tf.Tensor作为Python bool。使用if t is not None:代替if t:来测试是否定义了张量,并使用TensorFlow操作(如tf.cond)来执行以张量值为条件的子图。

我尝试过打印x的类型,它是一个有形状的张量(1,)。所以,它看起来正在发生的是,值不是作为标量值传递给lambda,而是作为具有形状(1,)的张量传递; %被广播,产生另一个形状的张量(1,),但该张量不能应用> =运算符。

有没有办法让这项工作?有没有办法获得实际标量我可以应用> =运算符?如果没有,是否可以使用map_fn的有效替代方案?

(我已经看过tf.cond了,在这种情况下我怎么能用它是不明显的。据我所知,tf.cond会产生一个op,而不是一个可调用的,所以如何我是否会在map_fn应用的lambda中使用它?)

1 个答案:

答案 0 :(得分:2)

您可以使用tf.map_fntf.cond执行此操作:

elems_shape = tf.shape(elems)
elems_flat = tf.reshape(elems, [-1])
tf_m2_flat = tf.map_fn(lambda x: tf.cond(x % 2 > 0, lambda: x + 1, lambda: x), elems_flat)
tf_m2 = tf.reshape(tf_m2_flat, elems_shape)

但您也可以像这样简单地使用tf.where

tf_m2 = tf.where(elems % 2 > 0, elems + 1, elems)