让X
为张量,其第一维未知(批量大小),第二维和第三维已知,n,m
。设Y
为相同维度的张量,这是X
的掩码,即对于批处理中的每个样本,它包含1
,其中X[b,n,m]
包含真值0
1}}它只是填充。
我想在X
的行/列上运行汇总操作,受掩码限制。即,如果某些n
样本中的某些行X
包含来自某个点n_0
的零,我不希望计算包含它。
虽然我可以手动为reduce_mean
或reduce_min
等操作解决此问题,但如果矩阵填充为零,我就无法使用reduce_prod
...
有没有办法使用矢量蒙版执行reduce_XXX
类型的Tensorflow操作?
感谢。
答案 0 :(得分:3)
您可以使用动态分区使用掩码值将数据分成两个张量
data = tf.constant([0, 1, 2, 3])
mask = tf.cast(data>0, tf.int32)
bad_data, good_data = tf.dynamic_partition(data, mask, 2)
sess.run(tf.reduce_prod(good_data))