Tensorflow如何检查张量行是否仅为零?

时间:2016-03-03 22:08:34

标签: python tensorflow

我正在训练一个简单的网络来预测单个对象的边界框坐标。然而,有些图片没有找到对象。由于网络总是进行预测,因此它还预测0和1之间的置信度值,其应指示图片中存在对象的概率。我对预测的张量称为logits,它是(batch_size, 5)张量(置信度,x,y,宽度和高度)。同样地,labels张量也是(batch_size, 5)

以前我只训练过总是有物体的图像,所以我基本上可以做到

loss = tf.l2_loss(logits - labels)

我想开始训练没有物体的照片,当照片中没有物体时,我不希望网络因其预测的任何坐标而受到惩罚。在这种情况下,重要的是置信度值,它应接近0(无对象)。

我应该如何构建标签和损失功能来实现这一目标?我可以将没有对象的图像标签设置为全零,但是如何检查特定行是否仅为零?在这种情况下,logits中相应的行也需要设置为零(置信度值除外!),以便因坐标而产生的损失也为零。

2 个答案:

答案 0 :(得分:1)

查找张量行是否全为零的方法:

import tensorflow as tf

image = tf.fill([8,8], 0)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
image_row = tf.slice(image, [1,0], [1, -1])
total = tf.reduce_sum(tf.abs(image_row))
is_all_zero = tf.equal(total, 0)
print sess.run([total, is_all_zero, image_row])

结果:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)]

答案 1 :(得分:0)

您可以使用tf.count_nonzero()检查张量是否全部为零。您可以查看指南here