如何选择将nd存储在另一个Tensor中的nd张量的行?

时间:2019-03-28 08:51:04

标签: tensorflow

我正在尝试沿第一个维度切片形状为(?, 32, 32)的张量。我必须选择两行索引存储在形状为(1, 2)的另一个张量中。我想要类似numpy的array[list of indexes, :, :]

我该怎么办?我需要此操作来计算model_fn函数内部的损失,并将其传递给我的自定义Tensorflow Estimator。

1 个答案:

答案 0 :(得分:0)

我用tf.gather_nd解决了。我用以下命令重塑了包含索引的张量:

ids = tf.reshape(tensor_with_indexes, shape=(-1, 1))

然后我申请:

new_tensor = tf.gather_nd(original_tensor, ids)