我希望将numpy操作转换为tensorflow。
给定2D ndarray的函数,我想使沿轴0不是最大的所有条目都更新为0。
def rows_to_zero(arr: np.ndarray):
def row_to_zero(row: np.ndarray):
row[row < row.max()] = 0
return row
return np.apply_along_axis(row_to_zero, 0, arr)
In: [[1, 2, 1, 4, 5, 3],
[2, 4, 5, 0, 1, 3],
[3, 5, 3, 6, 7, 1]]
Out: [[0, 0, 0, 0, 5, 0],
[0, 0, 5, 0, 0, 0],
[0, 0, 0, 0, 7, 0]]
我想使用Tensorflow张量编写相同的功能
任何人都可以为这样的事情提供帮助吗?
答案 0 :(得分:1)
您只需要tf.reduce_max
和tf.where
。
import tensorflow as tf
a = tf.constant([[1, 2, 1, 4, 5, 3],
[2, 4, 5, 0, 1, 3],
[3, 5, 3, 6, 7, 1]],tf.float32)
b = tf.reduce_max(a,axis=1,keepdims=True)
result = tf.where(tf.less(a,b),tf.zeros_like(a),a)
with tf.Session() as sess:
print(sess.run(result))
# [[0. 0. 0. 0. 5. 0.]
# [0. 0. 5. 0. 0. 0.]
# [0. 0. 0. 0. 7. 0.]]