我尝试使用tensorflow中的map_fn将变换应用于列向量,但它无效。
对于以下列向量:
elems = np.array([[1.0], [2.0], [3.0]])
当我这样做时:
tf_m = tf.map_fn(lambda x: x + 1.0, elems)
with tf.Session() as sess:
res = sess.run(tf_m)
print(str(res))
我得到了我期望的结果,即这个列向量:
[[2.]
[3.]
[4.]]
然而,当我这样做时:
tf_m2 = tf.map_fn(lambda x: x+1 if x % 2 > 0 else x, elems)
with tf.Session() as sess:
res = sess.run(tf_m2)
print(str(res))
代码失败,出现以下异常:
TypeError:不允许使用
tf.Tensor
作为Pythonbool
。使用if t is not None:
代替if t:
来测试是否定义了张量,并使用TensorFlow操作(如tf.cond)来执行以张量值为条件的子图。
我尝试过打印x的类型,它是一个有形状的张量(1,)。所以,它看起来正在发生的是,值不是作为标量值传递给lambda,而是作为具有形状(1,)的张量传递; %被广播,产生另一个形状的张量(1,),但该张量不能应用> =运算符。
有没有办法让这项工作?有没有办法获得实际标量我可以应用> =运算符?如果没有,是否可以使用map_fn的有效替代方案?
(我已经看过tf.cond了,在这种情况下我怎么能用它是不明显的。据我所知,tf.cond会产生一个op,而不是一个可调用的,所以如何我是否会在map_fn应用的lambda中使用它?)
答案 0 :(得分:2)
elems_shape = tf.shape(elems)
elems_flat = tf.reshape(elems, [-1])
tf_m2_flat = tf.map_fn(lambda x: tf.cond(x % 2 > 0, lambda: x + 1, lambda: x), elems_flat)
tf_m2 = tf.reshape(tf_m2_flat, elems_shape)
但您也可以像这样简单地使用tf.where
:
tf_m2 = tf.where(elems % 2 > 0, elems + 1, elems)