我想实现一个重构具有相同标签的张量的函数。例如,我们有一个张量特征BxNxC,其标签为BxN,B为批处理大小,C为特征的尺寸,标签的范围为K,我想要的输出为BxKxC。具有相同标签的要素将组合在一起,并输出具有特定标签的均值要素。
有人知道如何实现此功能吗?
我已经实现了一些,但是看起来很丑。
label_class = 12
for batch in range(part_pred.get_shape().as_list()[0]):
batch_label = part_label[batch]
batch_pred = mesh_point_net[batch,...]
for label in range(label_class):
condition = tf.equal(batch_label, label)
index = tf.where(condition)
pred = tf.reduce_mean(tf.gather(batch_pred, index), axis=0)
if label == 0:
batch_mean_pred = pred
else:
batch_mean_pred = tf.concat([batch_mean_pred, pred], axis=0)
if batch == 0:
mean_pred = tf.expand_dims(batch_mean_pred, axis=0)
else:
mean_pred = tf.concat([mean_pred,tf.expand_dims(batch_mean_pred, axis=0)], axis=0)