这个numpy数组操作的等效tensorflow操作是什么?

时间:2018-08-07 02:37:29

标签: python tensorflow

例如,y_pred是一个numpy数组,我需要此操作。

result = []
for i in y_pred:
    i = np.where(i == i.max(), 1, 0)
    result.append(i)

y_pred中每一行的最大数量将变为1,其余各行将变为0

如果y_pred是张量,如何执行此操作?

1 个答案:

答案 0 :(得分:0)

不确定是否已经有矢量化的功能可以一起完成所有这些操作。

假设这是您的数组。

a = tf.Variable([[10,11,12,13],
             [13,14,16,15],
             [18,16,17,15],
             [0,4,3,2]], name='a')

top = tf.nn.top_k(a,1)给出每一行中最大值的索引。现在我们从 ones acopy = tf.Variable(tf.zeros((4,4),tf.int32), name='acopy') 开始,它的形状与原始数组相同。

在那之后,我们将找到最大值的位置替换为零后,将行缝合在一起。

import tensorflow as tf

a = tf.Variable([[10,11,12,13],
                 [13,14,16,15],
                 [18,16,17,15],
                 [0,4,3,2]], name='a')

acopy = tf.Variable(tf.zeros((4,4),tf.int32), name='acopy')

top = tf.nn.top_k(a,1)

sess = tf.Session()

sess.run(tf.global_variables_initializer())

values,indices = sess.run(top)

shapeofa = a.eval(sess).shape

print(indices)

sess.run(tf.global_variables_initializer())

for i in range(0, shapeofa[0] ) :
    oldrow = tf.gather(acopy, i)
    index = tf.squeeze(indices[i])
    b = sess.run( tf.equal(index, 0) )
    if not b  :
        o = oldrow[0:index]
        newrow = tf.concat([o, tf.constant([1])], axis=0)
        c = sess.run(tf.equal(index, (tf.size(oldrow) - 1)))
        if not c:
            newrow = tf.concat([newrow, oldrow[(index+1):(tf.size(oldrow))]], axis=0)
    else :
        o1 = oldrow[1:tf.size(oldrow)]
        newrow = tf.concat([tf.constant([1]), o1], axis=0)
    acopy = tf.scatter_update(acopy, i, newrow)
    print(sess.run(acopy))

输出是这个。

  

[[0 0 0 1]

     

[0 0 1 0]

     

[1 0 0 0]

     

[0 1 0 0]]

您可以自己测试一下。