我在Keras的自定义图层中初始化张量A
,其中batchSize是占位符:
A = K.zeros([batchSize, 2, 2 ,2])
我还初始化了一个大小为B
的numpy数组[3,2,2,2]
。我想从[i,2,2,2]
B
随机选择i = 0,1,2
大小数组,并将其分配给A的第一个维度并重复此batchSize次数。
由于我无法显式循环遍历batchSize,我尝试了tensorflow.map_fn,如下所示:
ANew = tf.map_fn(lambda x: K.variable(B[np.random.randint(0,3,size=(1)).tolist()[0],:,:,:],
A, dtype=’float’, back_prop=False, infer_shape=True)
这会产生ANew
张量。但是,np.random.randint
看起来只被调用一次;结果,我总是选择相同的索引。如何修改代码以便np.random.randint(0,3,size=(1)).tolist()[0]
被称为batchSize次?
答案 0 :(得分:1)
您正在寻找K.gather
。
A = K.gather(B, indices_list)