keras concatenate嵌入层获取索引错误

时间:2018-01-08 02:45:14

标签: python keras

我正在进行多重嵌入,需要将所有嵌入层连接在一起进行培训。但是,我一直得到索引[1,0] = 7不在[0.7]错误中。

这是我做的:

models = []

i0 = Input(shape=(1,),name='model_store')

model_store = Embedding(1115,10,input_length=1)(i0)

model_store = Reshape(target_shape=(10,))(model_store)

models.append(model_store)

i1 = Input(shape=(1,),name='model_dow')

model_dow = Embedding(7,6,input_length=1)(i1)

model_dow = Reshape(target_shape=(6,))(model_dow)

models.append(model_dow)

i2 = Input(shape=(1,),name='model_promo')

model_promo = Dense(1,input_dim=1)(i2)

models.append(model_promo)

# there are 8 embedding and 3 dense layers in models.

# then, I do:

net = Concatenate()(models)

net = Dense(1000,kernel_initializer='uniform',activation='relu')(net)

# another dense layer

output = Dense(1,activation='relu')(net)

model = Model(inputs = [i0,i1,i2,...i10],outputs = output)

model.compile(loss='mean_absolute_error',optimizer='adam')

但是当我做model.fit()时,我得到的索引[] =不在[)错误中。

进入i0,i1,...,i10的输入类似于数组([[1],[2],[3],...]),所有长度为1的输入。

我还试图用Flatten()图层替换Reshape()图层,但是得到了同样的错误。

有人,请帮忙。

1 个答案:

答案 0 :(得分:0)

好吧,我发现了问题。

我没有以正确的形状提供数据。在Sequential API中,multiplue网络输入的输入数据应该是ndarrays列表(dict也可以工作)。虽然它说ndarray列表仍然适用于Functional API,但它在我的情况下并不起作用,可能是由于某些订单问题。 我使用输入名称(' model_store',' model_dow' ...)作为键使用ndarrys字典,并且它有效。