我正在进行多重嵌入,需要将所有嵌入层连接在一起进行培训。但是,我一直得到索引[1,0] = 7不在[0.7]错误中。
这是我做的:
models = []
i0 = Input(shape=(1,),name='model_store')
model_store = Embedding(1115,10,input_length=1)(i0)
model_store = Reshape(target_shape=(10,))(model_store)
models.append(model_store)
i1 = Input(shape=(1,),name='model_dow')
model_dow = Embedding(7,6,input_length=1)(i1)
model_dow = Reshape(target_shape=(6,))(model_dow)
models.append(model_dow)
i2 = Input(shape=(1,),name='model_promo')
model_promo = Dense(1,input_dim=1)(i2)
models.append(model_promo)
# there are 8 embedding and 3 dense layers in models.
# then, I do:
net = Concatenate()(models)
net = Dense(1000,kernel_initializer='uniform',activation='relu')(net)
# another dense layer
output = Dense(1,activation='relu')(net)
model = Model(inputs = [i0,i1,i2,...i10],outputs = output)
model.compile(loss='mean_absolute_error',optimizer='adam')
但是当我做model.fit()时,我得到的索引[] =不在[)错误中。
进入i0,i1,...,i10的输入类似于数组([[1],[2],[3],...]),所有长度为1的输入。
我还试图用Flatten()图层替换Reshape()图层,但是得到了同样的错误。
有人,请帮忙。
答案 0 :(得分:0)
好吧,我发现了问题。
我没有以正确的形状提供数据。在Sequential API中,multiplue网络输入的输入数据应该是ndarrays列表(dict也可以工作)。虽然它说ndarray列表仍然适用于Functional API,但它在我的情况下并不起作用,可能是由于某些订单问题。 我使用输入名称(' model_store',' model_dow' ...)作为键使用ndarrys字典,并且它有效。