在tensorflow map_fn中使用随机生成的整数作为张量索引

时间:2018-04-19 16:46:34

标签: python numpy tensorflow keras

我在Keras的自定义图层中初始化张量A,其中batchSize是占位符:

A = K.zeros([batchSize, 2, 2 ,2])

我还初始化了一个大小为B的numpy数组[3,2,2,2]。我想从[i,2,2,2] B随机选择i = 0,1,2大小数组,并将其分配给A的第一个维度并重复此batchSize次数。

由于我无法显式循环遍历batchSize,我尝试了tensorflow.map_fn,如下所示:

ANew = tf.map_fn(lambda x: K.variable(B[np.random.randint(0,3,size=(1)).tolist()[0],:,:,:],
                 A, dtype=’float’, back_prop=False, infer_shape=True)

这会产生ANew张量。但是,np.random.randint看起来只被调用一次;结果,我总是选择相同的索引。如何修改代码以便np.random.randint(0,3,size=(1)).tolist()[0]被称为batchSize次?

1 个答案:

答案 0 :(得分:1)

您正在寻找K.gather

A = K.gather(B, indices_list)