我使用相同的逻辑代码将tensorflow与numpy进行比较。
在实施tf.where时,我无法获得与np.where相同的结果 下面的代码或用法有什么问题?
数据:
X_batch = np.concatenate([np.arange(10).reshape(1, -1) for i in range(10)], axis=0)
tensorflow tf.where玩具代码:
X = tf.placeholder(dtype=tf.int32, shape=[10, 10])
with tf.Session() as sess:
print(sess.run(tf.where(X > 5, tf.zeros([10, 10], dtype=tf.int32),
X), feed_dict={X: X_batch}))
numpy np.where玩具代码:
np.where(X_batch > 5, np.zeros([10,10]), X_batch)
代码有一些拼写错误。纠正已经完成
答案 0 :(得分:0)
我编辑了代码。输入到tf.where()应该与np.where()相同。因此,对于tf.where()的参数,你应该给出零10 * 10矩阵和x_batch作为你的np.where()方法参数的输入。
XSRF-TOKEN=eyJpdiI6IkJtSmdiYUFVVXJxK2Q4bkxvaGdqTXc9PSIsInZhbHVlIjoid3ZBQXBaK2dIVE5DSnBDVnRFZW5vS0hcL1dZaDRENTFNWm1TZjFtT3paU1NicDEwV0RydGM0MGp1Q3Qrdk9rcEZ0S2RDdERBOXNheDl5U0xVUFpwMGd3PT0iLCJtYWMiOiJkMjQ3M2Q3MTlhYmE2MDhkMTg4N2YzZTM5NzNmMjk2MGM1NTEyZWYzNjlhZTgxMjcwOWExY2MyMjIwZDQ5ZjQ0In0%3D; ARRAffinity=7459802e08037292d7e12c172fb9496af377eac0039643a96805d4f72b58b44f
希望这有帮助。