tf.where()不会产生与np.where()相同的结果

时间:2017-10-25 03:34:11

标签: python numpy tensorflow

我使用相同的逻辑代码将tensorflow与numpy进行比较。

在实施tf.where时,我无法获得与np.where相同的结果 下面的代码或用法有什么问题?

数据:

X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0)

tensorflow tf.where玩具代码:

X = tf.placeholder(dtype=tf.int32, shape=[10, 10])

with tf.Session() as sess:
    print(sess.run(tf.where(X > 5, tf.zeros([10, 10], dtype=tf.int32), 
                            X), feed_dict={X: X_batch}))

numpy np.where玩具代码:

np.where(X_batch > 5, np.zeros([10,10]), X_batch)

代码有一些拼写错误。纠正已经完成

1 个答案:

答案 0 :(得分:0)

我编辑了代码。输入到tf.where()应该与np.where()相同。因此,对于tf.where()的参数,你应该给出零10 * 10矩阵和x_batch作为你的np.where()方法参数的输入。

XSRF-TOKEN=eyJpdiI6IkJtSmdiYUFVVXJxK2Q4bkxvaGdqTXc9PSIsInZhbHVlIjoid3ZBQXBaK2dIVE5DSnBDVnRFZW5vS0hcL1dZaDRENTFNWm1TZjFtT3paU1NicDEwV0RydGM0MGp1Q3Qrdk9rcEZ0S2RDdERBOXNheDl5U0xVUFpwMGd3PT0iLCJtYWMiOiJkMjQ3M2Q3MTlhYmE2MDhkMTg4N2YzZTM5NzNmMjk2MGM1NTEyZWYzNjlhZTgxMjcwOWExY2MyMjIwZDQ5ZjQ0In0%3D; ARRAffinity=7459802e08037292d7e12c172fb9496af377eac0039643a96805d4f72b58b44f

希望这有帮助。