RNNCell中的权重

时间:2019-01-20 17:55:10

标签: python tensorflow deep-learning chatbot

我正在集中精力在深度nlp中创建聊天机器人的udemy教程中,而iam则停留在此错误,该错误与rnn中的编码器层有关...

ValueError:尝试让第二个RNNCell使用已经具有权重的变量作用域的权重:“ bidirectional_rnn / fw / multi_rnn_cell / cell_0 / basic_lstm_cell”;并且该单元格未构造为BasicLSTMCell(...,复用= True)。要共享RNNCell的权重,只需在第二次计算中重用它,或者使用参数redirect = True创建一个新的值。

代码是

def encoder_rnn(rnn_inputs, rnn_size, num_layers, keep_prob, sequence_length):
    lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)
    lstm_dropout = tf.contrib.rnn.DropoutWrapper(lstm, input_keep_prob = keep_prob)
    encoder_cell = tf.contrib.rnn.MultiRNNCell([lstm_dropout] * num_layers)
    _, encoder_state = tf.nn.bidirectional_dynamic_rnn(cell_fw = encoder_cell,
                                                                    cell_bw = encoder_cell,
                                                                    sequence_length = sequence_length,
                                                                    inputs = rnn_inputs,
                                                                    dtype = tf.float32)


 #getting traing and test predictions
training_predictions, test_predictions=seq2seq_model(tf.reverse(inputs,[-1]), targets, keep_probe,
                                                     batch_size,
                                                     sequence_length,len(answerswords2int), len(questionswords2int),
                                                     encoding_embedding_size, decoding_embedding_size,
                                                     rnn_size, num_layers, questionswords2int)

请帮助!

0 个答案:

没有答案