Multiply only certain columns of a TensorFlow tensor

Date: 2018-04-05 08:43:48

Tags: matrix tensorflow slice

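I am modifying the loss function of one of my object detection neural networks. I basically have two arrays:

y_true: the labels to predict. A tf tensor of shape (x, y, z).
y_pred: the predicted values. A tf tensor of shape (x, y, z).
The x dimension is the batch size, the y dimension is the number of predicted objects in the image, and the z dimension contains the one-hot encoding of the class together with the bounding box of that class.

Now to the real question: what I want to do is multiply the first 5 values along the z dimension of y_pred by the first 5 values along the z dimension of y_true. All other values should remain unaffected. In NumPy it is very straightforward: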

y_pred[:,:,:5] *= y_true[:,:,:5]

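I find this hard to do in TensorFlow, because I cannot assign values to the original tensor and I want to keep all the other values the same. How can I do this in TensorFlow?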

1 Answer:

Answer 0 (score: 1)

Since v1.1, TensorFlow supports NumPy-like indexing; see Tensor.__getitem__.

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10], [10,20,30,40,50,60,70,80,90,100]]])
    print((y_pred[:,:,:5] * y_true[:,:,:5]).eval()) 
    # [[[   1    4    9   16   25]
    #   [ 100  400  900 1600 2500]]]

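Edit after comment:

Now the problem is the "*=" part, i.e. item assignment. This is not a straightforward operation in TensorFlow. In your case, however, it can easily be solved with tf.concat or tf.where (tf.dynamic_partition + tf.dynamic_stitch can be used for more complex cases).

Below are quick implementations of the first two solutions.

Solution using Tensor.__getitem__ and tf.concat: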

import tensorflow as tf

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # Multiply the first 5 entries along the last axis, keep the rest untouched,
    # then glue the two pieces back together along that axis.
    y_pred_edit = y_pred[:,:,:5] * y_true[:,:,:5]
    y_pred_rest = y_pred[:,:,5:]  # start at index 5 so entry 4 is not duplicated

    y_pred = tf.concat((y_pred_edit, y_pred_rest), axis=2)
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
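
As a sanity check, the slice-and-concatenate idea can be compared against the in-place NumPy update from the question. The sketch below uses NumPy only (an assumption, made so that it runs without a TensorFlow session); np.concatenate plays the role of tf.concat, and the helper name multiply_first_k is made up for illustration.

```python
import numpy as np

def multiply_first_k(y_pred, y_true, k):
    """Return a new array equal to y_pred, except that the first k entries
    along the last axis are multiplied by the matching entries of y_true."""
    edited = y_pred[..., :k] * y_true[..., :k]   # the part that changes
    rest = y_pred[..., k:]                       # the part that stays as-is
    return np.concatenate((edited, rest), axis=-1)

y_pred = np.arange(1, 11).reshape(1, 1, 10)
y_true = np.arange(1, 11).reshape(1, 1, 10)

# Reference: the in-place update from the question, applied to a copy.
expected = y_pred.copy()
expected[:, :, :5] *= y_true[:, :, :5]

result = multiply_first_k(y_pred, y_true, 5)
print(np.array_equal(result, expected))  # True
```

Because nothing is written in place, y_pred itself is left unchanged, which matches how the TensorFlow version has to work.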

Solution using tf.where:

import tensorflow as tf

def select_n_first_indices(n, length):
    """ Return a boolean tensor of shape (length,) whose first n elements are True
        and the rest False, i.e. [True] * n + [False] * (length - n).
    """
    n_ones = tf.ones([n])
    rest_zeros = tf.zeros([length - n])
    indices = tf.cast(tf.concat((n_ones, rest_zeros), axis=0), dtype=tf.bool)

    return indices

with tf.Session() as sess:
    y_pred = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])
    y_true = tf.constant([[[1,2,3,4,5,6,7,8,9,10]]])

    # With a rank-1 condition, tf.where only selects along the first axis (see doc).
    # Here the condition applies to the last axis (axis 2), so we need either to
    # manually broadcast the condition tensor, or transpose the target tensors.
    # Here is a quick demonstration with the 2nd solution:
    y_pred_transposed = tf.transpose(y_pred, [2, 0, 1])
    y_true_transposed = tf.transpose(y_true, [2, 0, 1])

    edit_indices = select_n_first_indices(5, tf.shape(y_pred_transposed)[0])

    y_pred_transposed = tf.where(condition=edit_indices,
                                 x=y_pred_transposed * y_true_transposed,
                                 y=y_pred_transposed)

    # Transpose back:
    y_pred = tf.transpose(y_pred_transposed, [1, 2, 0])
    print(y_pred.eval())
    # [[[ 1  4  9 16 25  6  7  8  9 10]]]
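
The code comments above also mention a first option: broadcasting the condition by hand instead of transposing. A minimal sketch of that idea follows. NumPy is used as a stand-in (an assumption, so that it runs without a session), and mask_first_k is a hypothetical helper name. The mask is built along the last axis and expanded to the full shape before the elementwise selection.

```python
import numpy as np

def mask_first_k(shape, k):
    """Boolean array of the given shape: True for the first k positions
    along the last axis, False elsewhere."""
    last_axis_mask = np.arange(shape[-1]) < k       # shape (z,)
    return np.broadcast_to(last_axis_mask, shape)   # expanded to (x, y, z)

y_pred = np.arange(1, 11).reshape(1, 1, 10)
y_true = np.arange(1, 11).reshape(1, 1, 10)

mask = mask_first_k(y_pred.shape, 5)
# Elementwise select: the product where the mask is True, the original value elsewhere.
result = np.where(mask, y_pred * y_true, y_pred)
print(result.tolist())  # [[[1, 4, 9, 16, 25, 6, 7, 8, 9, 10]]]
```

In TensorFlow, a mask with the same shape as x and y could presumably be passed as the condition of tf.where without any transpose. How best to expand the mask to that shape depends on the TensorFlow version, so treat that step as an assumption rather than a tested recipe.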