在运行时在张量中的多个位置查找索引,并将其替换为0

时间:2019-06-10 00:55:58

标签: python tensorflow

我希望在运行时以0 = (n,m)的张量为多个位置分配0。

我使用Tensorflow中的where子句计算了索引,并调用了scatter_nd_update函数,以便在新发现的多个位置分配一个tf.constant(0)

oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0))
updates = tf.placeholder(tf.float64, [None, None])
oscvec_empty = tf.placeholder(tf.int64, [None])
tf.cond(tf.not_equal(tf.size(oscvec), 0), tf.scatter_nd_update(save_parms, oscvec, tf.constant(0, dtype=tf.float64)),
        tf.scatter_nd_update(save_parms, oscvec_empty, updates))

我希望在条件不满足时tf.where返回一个空张量,并且在某个时候返回save_parms的索引的非空张量。我决定创建并清空oscvec_empty张量以处理tf.where的结果返回空张量的情况。但这似乎不起作用。...从以下错误中可以看出,当使用Tensorflow if-else条件tf.cond-通过{{1来更新save_parms参数张量时,会产生以下错误}}函数:

tf.scatter_nd_update

当oscvec为非空时,是否有一种方法可以替换ValueError: Shape must be at least rank 1 but is rank 0 for 'ScatterNdUpdate' (op: 'ScatterNdUpdate') with input shapes: [55], [?,1], []. 张量中多个位置上的值,而当oscvec为空时,是否可以这样做呢? save_parms张量对应于基于给定准则应用于sgn的符号函数的结果。

1 个答案:

答案 0 :(得分:1)

您可以使用tf.where()来代替这种复杂的方法。

import tensorflow as tf

vec1 = tf.constant([[ 0.05734377, 0.80147606, -1.2730557 ], [ 0.42826906, 1.1943488 , -0.10129673]])
vec2 = tf.constant([[ 1.5461133 , -0.38455755, -0.79792875], [ 1.5374309 , -1.5657802 , 0.05546811]])

sgn1 = tf.sign(vec1)
sgn2 = tf.sign(vec2)
save_parms = tf.random_normal(shape=sgn1.shape)
oscvec = tf.where(tf.math.logical_and(sgn2 > 0, sgn1 < 0),tf.zeros_like(save_parms),save_parms)

with tf.Session() as sess:
    save_parms_val, oscvec_val = sess.run([save_parms, oscvec])
    print(save_parms_val)
    print(oscvec_val)

[[ 0.75645643 -0.646291   -1.2194813 ]
 [ 1.5204562  -1.0625905   2.9939709 ]]
[[ 0.75645643 -0.646291   -1.2194813 ]
 [ 1.5204562  -1.0625905   0.        ]]