Problem with LSTM, GRU and word embeddings

Time: 2016-09-24 14:56:36

Tags: python machine-learning keras lstm

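I have been trying to get my model to work for three days now, and I can't. I am extracting words from sentences, a task known as "aspect extraction". The problem is as follows. My input is a (3044, 80, 145) matrix:

3044 - number of sentences
80 - maximum number of words in a sentence
145 - dimension of the word_embeddings (including POS tags, etc.)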

My output is a (3044, 80) matrix, where each row of 80 values is a one-hot vector with a 1 at the position of an aspect.
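The data layout described above can be sketched with NumPy. This is a minimal sketch with placeholder arrays, not the real pickled data: the inputs are zeros and the aspect positions are chosen at random, one per sentence, purely for illustration (real sentences may contain more than one aspect). It shows how the (3044, 80) label matrix becomes one target per time step once reshaped to (3044, 80, 1).

```python
import numpy as np

n_sentences, max_words, emb_dim = 3044, 80, 145

# Placeholder input (zeros, not real features): one 145-dim vector per word slot.
train_inp = np.zeros((n_sentences, max_words, emb_dim), dtype=np.float32)

# Placeholder labels: a 1 at one randomly chosen (hypothetical) aspect position per sentence.
rng = np.random.default_rng(0)
train_out = np.zeros((n_sentences, max_words), dtype=np.float32)
aspect_pos = rng.integers(0, max_words, size=n_sentences)
train_out[np.arange(n_sentences), aspect_pos] = 1.0

# One target per time step, matching a recurrent layer with return_sequences=True.
train_out_seq = train_out.reshape(n_sentences, max_words, 1)

print(train_inp.shape)      # (3044, 80, 145)
print(train_out_seq.shape)  # (3044, 80, 1)
```

The reshape only adds a trailing axis of size 1; the 0/1 values per word position are unchanged.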

I have tried both LSTM and GRU, but every time I get (almost exactly) the following error. Any guidance would be appreciated.

TypeError: ('The following error happened while compiling the node', forall_inplace,cpu,scan_fn}(TensorConstant{80}, InplaceDimShuffle{1,0,2}.0, IncSubtensor{InplaceSet;:int64:}.0, TensorConstant{80}, gru_1_U_z, gru_1_U_r, gru_1_U_h), '\n', "Inconsistency in the inner graph of scan 'scan_fn' : an input and an output are associated with the same recurrent state and should have the same type but have type 'TensorType(float32, col)' and 'TensorType(float32, matrix)' respectively.")

My code:

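# Keras 1.x API (TimeDistributedDense, nb_epoch), running on the Theano backend
import pickle

from keras.models import Sequential
from keras.layers.core import Activation, TimeDistributedDense
from keras.layers.recurrent import GRU

# Load the pickled feature and label arrays
with open("train_inp_145.pkl", "rb") as f:
    train_inp = pickle.load(f)
with open("train_out_145.pkl", "rb") as f:
    train_out = pickle.load(f)

# (sentences, max words, features) and one 0/1 target per word position
train_inp = train_inp.reshape(3044, 80, 145).astype('float32')
train_out = train_out.reshape(3044, 80, 1)

model = Sequential()

model.add(GRU(1, return_sequences=True, input_shape=(80, 145)))
model.add(TimeDistributedDense(1))
# Softmax over a last axis of size 1 always outputs 1.0
model.add(Activation("softmax"))

model.compile(loss='mse',
              optimizer='rmsprop',
              metrics=['accuracy'])

model.summary()  # prints the table itself and returns None
model.fit(train_inp, train_out, validation_split=0.2,
          batch_size=100,
          nb_epoch=2)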
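Separately from the Theano error, one property of the model as written can be checked with plain NumPy: `Activation("softmax")` normalizes over the last axis, and here that axis has size 1 (output shape (None, 80, 1)), so every output is exactly 1.0 whatever the input. The sketch below compares that with a per-step sigmoid and with a softmax taken over the 80 word positions instead. Which alternative fits depends on whether a sentence can contain more than one aspect, which the question does not settle. The `softmax` and `sigmoid` helpers are hand-written for illustration, not Keras functions, and the logits are random stand-ins for the Dense(1) output.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 80, 1))  # stand-in for the Dense(1) output: (batch, 80, 1)

# Softmax over the last axis (size 1): every entry becomes 1.0, so nothing can be learned.
per_step_softmax = softmax(logits, axis=-1)

# Per-step sigmoid: an independent probability for each word (several aspects possible).
per_step_sigmoid = sigmoid(logits)

# Softmax over the 80 positions: one distribution per sentence (exactly one aspect).
over_positions = softmax(logits[..., 0], axis=1)

print(np.allclose(per_step_softmax, 1.0))  # True
print(over_positions.sum(axis=1))          # each row sums to 1 (up to rounding)
```

With a per-step sigmoid, the usual Keras pairing would be `loss='binary_crossentropy'` rather than `'mse'`; this is a suggestion, not something tested in the question.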

Model summary:

Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
gru_1 (GRU)                      (None, 80, 1)         441         gru_input_1[0][0]                
____________________________________________________________________________________________________
timedistributeddense_1 (TimeDistr(None, 80, 1)         2           gru_1[0][0]                      
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 80, 1)         0           timedistributeddense_1[0][0]     
====================================================================================================
Total params: 443
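The parameter counts in the summary can be reproduced by hand. A Keras 1.x GRU keeps, for each of its three gates (z, r, h), an input kernel W, a recurrent kernel U (the `gru_1_U_z`, `gru_1_U_r`, `gru_1_U_h` in the error) and a bias, giving 3 × (input_dim × units + units × units + units) parameters: 3 × (145 + 1 + 1) = 441 here, plus 1 weight + 1 bias = 2 for the per-step Dense(1). The NumPy sketch below checks those numbers and runs a plain GRU forward pass with the same shapes to show that `return_sequences=True` yields (batch, 80, units). It uses the standard GRU equations with an ordinary sigmoid (Keras 1.x defaults to `hard_sigmoid` for the gates), random weights and zero biases; it is an illustration, not Keras's implementation, and the helper names and dict-of-gates layout are hypothetical.

```python
import numpy as np

def gru_param_count(input_dim, units):
    # Three gates (z, r, h), each with an input kernel, a recurrent kernel and a bias.
    return 3 * (input_dim * units + units * units + units)

def dense_param_count(input_dim, units):
    return input_dim * units + units

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_forward(x, W, U, b):
    """Plain GRU over x of shape (batch, timesteps, input_dim), returning every step.

    W, U, b are dicts keyed by gate name 'z', 'r', 'h' (hypothetical layout).
    """
    batch, timesteps, _ = x.shape
    units = U["z"].shape[0]
    h = np.zeros((batch, units))  # the recurrent state is (batch, units) at every step
    outputs = []
    for t in range(timesteps):
        x_t = x[:, t, :]
        z = sigmoid(x_t @ W["z"] + h @ U["z"] + b["z"])
        r = sigmoid(x_t @ W["r"] + h @ U["r"] + b["r"])
        h_tilde = np.tanh(x_t @ W["h"] + (r * h) @ U["h"] + b["h"])
        h = z * h + (1.0 - z) * h_tilde
        outputs.append(h)
    return np.stack(outputs, axis=1)  # (batch, timesteps, units)

input_dim, units, timesteps = 145, 1, 80
rng = np.random.default_rng(0)
W = {g: rng.normal(scale=0.1, size=(input_dim, units)) for g in "zrh"}
U = {g: rng.normal(scale=0.1, size=(units, units)) for g in "zrh"}
b = {g: np.zeros(units) for g in "zrh"}

x = rng.normal(size=(2, timesteps, input_dim))
seq = gru_forward(x, W, U, b)

print(gru_param_count(input_dim, units))  # 441
print(dense_param_count(units, 1))        # 2
print(seq.shape)                          # (2, 80, 1)
```

One plausible reading of the error, not confirmed in the question: Theano distinguishes a (batch, 1) tensor whose second dimension is marked broadcastable (type `col`) from a general `matrix`, and the message says the scan's recurrent state enters as `col` but leaves as `matrix`. That points at the single GRU unit; trying the GRU with more than one unit, still followed by the per-step Dense(1), is a cheap way to test this.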

0 answers