我正在尝试沿第一个维度切片形状为(?, 32, 32)
的张量。我必须选择两行索引存储在形状为(1, 2)
的另一个张量中。我想要类似numpy的array[list of indexes, :, :]
。
我该怎么办?我需要此操作来计算model_fn
函数内部的损失,并将其传递给我的自定义Tensorflow Estimator。
答案 0 :(得分:0)
我用tf.gather_nd
解决了。我用以下命令重塑了包含索引的张量:
ids = tf.reshape(tensor_with_indexes, shape=(-1, 1))
然后我申请:
new_tensor = tf.gather_nd(original_tensor, ids)