I'm working from the spelling-bee code and applying it to a similar seq2seq task. I'm struggling with the predictions and the output of the argmax function: for some reason, argmax returns only 0 in every case.
I've changed a lot of parameters and tried other axes, but nothing seems to help. I'd appreciate any clue!
I have:
from keras.layers import (Input, Embedding, Bidirectional, LSTM,
                          RepeatVector, TimeDistributed, Dense)
from keras.models import Model
from keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import numpy as np

parms = {'verbose': 2}
lstm_params = {}

def get_rnn(return_sequences=True):
    return LSTM(dim, return_sequences=return_sequences,
                recurrent_dropout=0.2, implementation=1, dropout=0.2)

inp = Input((maxlen_p,))
x = Embedding(input_vocab_size, 60)(inp)
x = Bidirectional(get_rnn())(x)
x = get_rnn(False)(x)
x = RepeatVector(maxlen)(x)
x = get_rnn()(x)
x = get_rnn()(x)
x = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(x)

model = Model(inp, x)
model.compile(loss='sparse_categorical_crossentropy', optimizer='Adam',
              metrics=['acc'])

hist = model.fit(input_train, np.expand_dims(labels_train, -1),
                 validation_data=[input_test, np.expand_dims(labels_test, -1)],
                 batch_size=128, **parms, epochs=3)

SVG(model_to_dot(model).create(prog='dot', format='svg'))
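As a sanity check on the label shapes (toy values made up for illustration, not my real data): sparse_categorical_crossentropy wants rank-3 targets of shape (batch, timesteps, 1) to match the (batch, timesteps, vocab) softmax output, which is why I use np.expand_dims above:

```python
import numpy as np

# Toy stand-ins for labels_train: batch of 2, maxlen of 4 (made-up values)
labels_toy = np.array([[24, 360, 60, 0],
                       [80, 585, 0, 0]])
y = np.expand_dims(labels_toy, -1)
print(y.shape)  # (2, 4, 1) -- rank-3 targets, as the sparse loss expects
```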
The problem shows up in this function:
def eval_keras(input):
    preds = model.predict(input, batch_size=128)
    predict = np.argmax(preds, axis=2)
    return (np.mean([all(real == p) for real, p in zip(labels_test, predict)]),
            predict)
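To double-check the axis choice, here is a tiny made-up example (1 sample, 3 timesteps, 4 classes). axis=2 takes the argmax over the class dimension, which is what I want:

```python
import numpy as np

# Made-up softmax-like outputs: shape (1, 3, 4)
preds_toy = np.array([[[0.10, 0.70, 0.10, 0.10],
                       [0.60, 0.20, 0.10, 0.10],
                       [0.05, 0.05, 0.20, 0.70]]])
print(np.argmax(preds_toy, axis=2))  # [[1 0 3]] -- best class per timestep
```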
preds:
[[[9.1350e-02 4.3054e-04 7.0428e-04 ... 5.0601e-04 7.1275e-04 5.8476e-03]
[5.9895e-01 3.5628e-05 7.1672e-05 ... 4.2559e-05 7.7454e-05 3.3954e-03]
[7.2249e-01 1.6146e-05 3.3864e-05 ... 2.0008e-05 3.8247e-05 2.3551e-03] ...
predict:
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
real (values):
[ 24 360 60 80 585 73 706 595 766 625 240 284 8 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0]
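One thing I noticed while checking the data: most positions in the real target vector above are padding zeros. Counting on that vector (13 non-zero tokens out of 80 positions):

```python
import numpy as np

# The "real" target vector shown above: 13 non-zero tokens, then padding zeros
real = np.array([24, 360, 60, 80, 585, 73, 706, 595, 766, 625, 240, 284, 8]
                + [0] * 67)
print(np.mean(real == 0))  # 0.8375 -- over 80% of positions are the padding class
```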