Keras k_gather函数

时间:2018-05-22 17:48:34

标签: tensorflow keras

目前我正试图让Keras后端函数k_gather在R中工作。到目前为止还没有运气。我只能找到有关tensorflow收集功能的适当文档。如果我遵循这个文档,下面的代码应该提取张量a的(1,1,1)-entry。

library(keras)

a =  k_constant(c(1L, 2L,3L,4L), dtype = 'int32' , shape = c(1L, 1L, 4L ))
c = k_constant(c(1L, 0L,0L,0L), dtype = 'int32' , shape = c(1L, 1L, 4L ))
out = k_gather(a , indices =  c )
sess$run(out)

然而,它似乎并没有这样的方式。当我运行它时,我收到错误

Error in py_call_impl(callable, dots$args, dots$keywords) : 
InvalidArgumentError: indices[0,0,0] = 1 is not in [0, 1)

错误并不代表我,因为out的形状似乎是

shape=(1, 1, 4, 1, 4)

而不仅仅是

 shape=(1, 1, 4)

它是如何运作的。如何提取我可爱张量的第一个组成部分?

1 个答案:

答案 0 :(得分:0)

我找到了解决这个k_gather问题的方法。我现在使用以下程序来置换张量a

library(keras)

a =  k_constant(c(1L, 2L,3L,4L), dtype = 'int32' , shape = c(1L, 1L, 4L ))

a_1 =  a[,,1:2]
a_2 =  a[,,3:4]

a_new = k_concatenate( list(a_2, a_1))

sess = k_get_session()

sess$run(a_new)

它没有解决我的问题k_gather,但它做了我想要的。