在Tensorflow中,如何重建具有相同标签的张量?

时间:2018-11-16 17:41:34

标签: python tensorflow

我想实现一个重构具有相同标签的张量的函数。例如,我们有一个张量特征BxNxC,其标签为BxN,B为批处理大小,C为特征的尺寸,标签的范围为K,我想要的输出为BxKxC。具有相同标签的要素将组合在一起,并输出具有特定标签的均值要素。

有人知道如何实现此功能吗?

我已经实现了一些,但是看起来很丑。

label_class = 12
for batch in range(part_pred.get_shape().as_list()[0]):                                                                                               
    batch_label = part_label[batch]
    batch_pred = mesh_point_net[batch,...]    
    for label in range(label_class):
        condition = tf.equal(batch_label, label)
        index = tf.where(condition) 
        pred = tf.reduce_mean(tf.gather(batch_pred, index), axis=0)
        if label == 0:
            batch_mean_pred = pred
        else:
            batch_mean_pred = tf.concat([batch_mean_pred, pred], axis=0)
     if batch == 0:
         mean_pred = tf.expand_dims(batch_mean_pred, axis=0)
     else:
         mean_pred = tf.concat([mean_pred,tf.expand_dims(batch_mean_pred, axis=0)], axis=0) 

0 个答案:

没有答案