深度串联在Keras中

时间:2019-03-25 09:23:29

标签: python keras one-hot-encoding

我正在尝试将一个热向量(例如Pytorch在StarGAN中使用Pytorch的实现方式进行深度连接)

input_img = Input(shape = (row, col, chann))
one_hot = Input(shape = (7, ))

我在(it was class indexes)之前偶然发现了相同的问题,因此我使用了RepeatVector + Reshape然后进行了Concatenate。但是我发现,当您要将3D重复为4D(包括batch_num)时,RepeatVector不兼容。

如何在Keras中实现此方法?我发现Upsampling2D可以完成工作,但是我不知道它是否能够在上采样过程中保持单热矢量结构

1 个答案:

答案 0 :(得分:0)

我从How to use tile function in Keras?找到了一个可以使用tile的想法,但是您需要重塑one_hot使其具有与input_img相同的尺寸数

one_hot = Reshape((1, 1, 6))(one_hot)
one_hot = Lambda(K.tile, arguments = {'n' : (-1, row, col, 1)})(one_hot)
model_input = Concatenate()([input_img, one_hot])