Tensorflow:获取数组行的索引为零

时间:2017-01-23 12:43:05

标签: arrays indexing tensorflow

对于张量

[[1 2 3 1]
 [0 0 0 0]
 [1 3 5 7]
 [0 0 0 0]
 [3 5 7 8]]

如何获取0行的索引?即Tensorflow中的列表[1,3]?

2 个答案:

答案 0 :(得分:3)

据我所知,你不能像在NumPy这样的高级库中那样在一个命令中实现这一点。 如果你真的想使用TF函数我可以建议一些像:

x = tf.Variable([
    [1,2,3,1],
    [0,0,0,0],
    [1,3,5,7],
    [0,0,0,0],
    [3,5,7,8]])

y = tf.Variable([0,0,0,0])
condition = tf.equal(x, y)
indices = tf.where(condition)

这将产生以下结果:

[[1 0]
 [1 1]
 [1 2]
 [1 3]
 [3 0]
 [3 1]
 [3 2]
 [3 3]]

如果您只想获得零线,则可以使用以下内容:

row_wise_sum = tf.reduce_sum(tf.abs(x),1)
select_zero_sum = tf.where(tf.equal(row_wise_sum,0))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(select_zero_sum))

结果是:

[[1]
 [3]]

答案 1 :(得分:0)

它也可以更容易地完成:

    g = tf.Graph()
    with g.as_default():
       a = tf.placeholder(dtype=tf.float32, shape=[3, 4])
       b = tf.placeholder(dtype=tf.float32, shape=[1, 4])

       res = tf.not_equal(a, b)
       res = tf.reduce_sum(tf.cast(res, tf.float32), 1)
       res = tf.where(tf.equal(res1, [0.0]))[0]

    with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    dict_ = {
        a:np.array([[2.0,6.0,3.0,2.0],
                   [1.0,8.0,32.0,1.0],
                   [1.0,8.0,3.0,11.0]]),
        b:np.array([[1.0,8.0,3.0,11.0]])
    }

    print(sess.run(res, feed_dict=dict_))
    :[2]