索引Keras Tensor

时间:2018-06-02 02:41:04

标签: python tensorflow keras keras-layer

我的Keras功能模型的输出层是维度x的张量(None, 1344, 2)。我希望从n < 1344的第二维中提取x个条目,并创建一个大小为y的新张量(None, n, 2)

通过简单地访问n来提取x[:, :n,:]个连续条目似乎很简单,但如果n索引是非连续的,那么(看似很难)。在Keras有这么干净的方式吗?

到目前为止,这是我的方法。

实验1 (切片张量,连续索引,有效):

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
(None, 5, 2) # as printed in my code, works!

实验2 (在任意索引处索引张量,失败)

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
foo = np.array([1,2,4,5,8])
print('arbitrary indexing, shape is', K.int_shape(x[:,foo,:]))

Keras返回以下错误:

ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op: 
'Pack') with input shapes: [], [5], [].

实验3 (张量流后端功能) 我也试过K.backend.gather,但它的用法不清楚,因为1)Keras文档声明索引应该是整数的张量,如果我的目标是在numpy.where中提取条目,则没有Keras等价于x {1}}满足某个条件且2)K.backend.gather似乎从axis = 0中提取条目,而我想从x的第二维中提取。

1 个答案:

答案 0 :(得分:1)

您正在寻找将根据索引数组进行索引的tf.gather_nd

# From documentation
indices = [[0, 0], [1, 1]]
params = [['a', 'b'], ['c', 'd']]
output = ['a', 'd']

要在Keras模型中使用它,请确保将其包裹在Lambda之类的图层中。