卡在tensorflow高级索引上

时间:2018-10-03 09:14:19

标签: python tensorflow

给出一个形状为(?,5,5)的输入张量,我需要通过对形状为(120,5,2)的索引张量指定的元素求和来找到每个示例的最大和。索引张量列出了将示例的5x5矩阵求和的120种方法。 例如:

Input tensor (?,5,5):
[
  [
    [0,1,0,0,0],
    [0,0,0,0,1],
    [1,0,0,0,0],
    [0,0,0,1,0],
    [0,0,1,0,0]
  ],
  [
    ...
  ],
  ...
]

Index tensor(120,5,2):
[
  [
    [0,1], 
    [1,4], 
    [2,2], 
    [3,0], 
    [4,3]  
  ],
  [
    ...
  ],
...
]

在这里,第一求和的结果为1 + 1 + 0 + 0 + 0 = 2。 我需要找到每个示例的索引数组给出的所有120种方式的最大和。

在numpy中,我将对整数索引数组使用高级索引,但不幸的是tf不支持此功能。我找到了tf.gather_nd,但是似乎我这个函数假设我知道我不知道的批次中每个示例的索引。

1 个答案:

答案 0 :(得分:1)

解决了。 诀窍是移置轴。这样,未知尺寸可以推到最后,并且collect_nd可以选择未知尺寸之前的所有切片。

这是完整的代码,如果有人在乎...

def permute(a, l, r):
    if l==r:
        yield list(zip([0,1,2,3,4],a))
    else:
        for i in range(l,r+1):
            a[l], a[i] = a[i], a[l]
            yield from permute(a, l+1, r)
            a[l], a[i] = a[i], a[l]

def multi_class_acc_positions(pred, target, input):
    pred_5x5 = tf.reshape(pred, [-1, 5, 5])
    target_5x5 = tf.reshape(target, [-1, 5, 5])
    pred_5x5_T = tf.transpose(pred_5x5, (1,2,0))
    all_perms = tf.constant(list(permute([0,1,2,3,4],0,4)))
    selected_elemens_per_example = tf.gather_nd(pred_5x5_T, all_perms)
    sums_per_example = tf.reduce_sum(selected_elemens_per_example, axis=1)
    best_perm_per_example_index = tf.argmax(sums_per_example, axis=0)
    best_perms = tf.gather_nd(all_perms, best_perm_per_example_index[:,tf.newaxis])[:,:,1]
    pred_5x5_one_hot = tf.reshape(tf.one_hot(best_perms, depth=5), (-1, 5, 5))
    correct_prediction = tf.equal(tf.argmax(pred_5x5_one_hot, axis=2), tf.argmax(target_5x5, axis=2))
    all_correct = tf.reduce_min(tf.cast(correct_prediction, tf.float32), 1)
    acc = tf.reduce_mean(all_correct)
    return acc