给出一个形状为(?,5,5)的输入张量,我需要通过对形状为(120,5,2)的索引张量指定的元素求和来找到每个示例的最大和。索引张量列出了将示例的5x5矩阵求和的120种方法。 例如:
Input tensor (?,5,5):
[
[
[0,1,0,0,0],
[0,0,0,0,1],
[1,0,0,0,0],
[0,0,0,1,0],
[0,0,1,0,0]
],
[
...
],
...
]
Index tensor(120,5,2):
[
[
[0,1],
[1,4],
[2,2],
[3,0],
[4,3]
],
[
...
],
...
]
在这里,第一求和的结果为1 + 1 + 0 + 0 + 0 = 2。 我需要找到每个示例的索引数组给出的所有120种方式的最大和。
在numpy中,我将对整数索引数组使用高级索引,但不幸的是tf不支持此功能。我找到了tf.gather_nd,但是似乎我这个函数假设我知道我不知道的批次中每个示例的索引。
答案 0 :(得分:1)
解决了。 诀窍是移置轴。这样,未知尺寸可以推到最后,并且collect_nd可以选择未知尺寸之前的所有切片。
这是完整的代码,如果有人在乎...
def permute(a, l, r):
if l==r:
yield list(zip([0,1,2,3,4],a))
else:
for i in range(l,r+1):
a[l], a[i] = a[i], a[l]
yield from permute(a, l+1, r)
a[l], a[i] = a[i], a[l]
def multi_class_acc_positions(pred, target, input):
pred_5x5 = tf.reshape(pred, [-1, 5, 5])
target_5x5 = tf.reshape(target, [-1, 5, 5])
pred_5x5_T = tf.transpose(pred_5x5, (1,2,0))
all_perms = tf.constant(list(permute([0,1,2,3,4],0,4)))
selected_elemens_per_example = tf.gather_nd(pred_5x5_T, all_perms)
sums_per_example = tf.reduce_sum(selected_elemens_per_example, axis=1)
best_perm_per_example_index = tf.argmax(sums_per_example, axis=0)
best_perms = tf.gather_nd(all_perms, best_perm_per_example_index[:,tf.newaxis])[:,:,1]
pred_5x5_one_hot = tf.reshape(tf.one_hot(best_perms, depth=5), (-1, 5, 5))
correct_prediction = tf.equal(tf.argmax(pred_5x5_one_hot, axis=2), tf.argmax(target_5x5, axis=2))
all_correct = tf.reduce_min(tf.cast(correct_prediction, tf.float32), 1)
acc = tf.reduce_mean(all_correct)
return acc