我希望在运行时以0 = (n,m)
的张量为多个位置分配0。
我使用Tensorflow中的where子句计算了索引,并调用了scatter_nd_update
函数,以便在新发现的多个位置分配一个tf.constant(0)
。
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0))
updates = tf.placeholder(tf.float64, [None, None])
oscvec_empty = tf.placeholder(tf.int64, [None])
tf.cond(tf.not_equal(tf.size(oscvec), 0), tf.scatter_nd_update(save_parms, oscvec, tf.constant(0, dtype=tf.float64)),
tf.scatter_nd_update(save_parms, oscvec_empty, updates))
我希望在条件不满足时tf.where
返回一个空张量,并且在某个时候返回save_parms
的索引的非空张量。我决定创建并清空oscvec_empty
张量以处理tf.where
的结果返回空张量的情况。但这似乎不起作用。...从以下错误中可以看出,当使用Tensorflow if-else条件tf.cond
-通过{{1来更新save_parms
参数张量时,会产生以下错误}}函数:
tf.scatter_nd_update
当oscvec为非空时,是否有一种方法可以替换ValueError: Shape must be at least rank 1 but is rank 0 for 'ScatterNdUpdate' (op: 'ScatterNdUpdate') with input shapes: [55], [?,1], [].
张量中多个位置上的值,而当oscvec为空时,是否可以这样做呢? save_parms
张量对应于基于给定准则应用于sgn
的符号函数的结果。
答案 0 :(得分:1)
您可以使用tf.where()
来代替这种复杂的方法。
import tensorflow as tf
vec1 = tf.constant([[ 0.05734377, 0.80147606, -1.2730557 ], [ 0.42826906, 1.1943488 , -0.10129673]])
vec2 = tf.constant([[ 1.5461133 , -0.38455755, -0.79792875], [ 1.5374309 , -1.5657802 , 0.05546811]])
sgn1 = tf.sign(vec1)
sgn2 = tf.sign(vec2)
save_parms = tf.random_normal(shape=sgn1.shape)
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0),tf.zeros_like(save_parms),save_parms)
with tf.Session() as sess:
save_parms_val, oscvec_val = sess.run([save_parms, oscvec])
print(save_parms_val)
print(oscvec_val)
[[ 0.75645643 -0.646291 -1.2194813 ]
[ 1.5204562 -1.0625905 2.9939709 ]]
[[ 0.75645643 -0.646291 -1.2194813 ]
[ 1.5204562 -1.0625905 0. ]]