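I am trying to implement a beam search decoder for a project. At the moment I am following the TensorFlow tutorial from:
I am new to deep learning and TensorFlow, and I'm not sure whether I'm going about this the right way.
I have edited the code to accommodate the beam search decoder: I replaced the GRU layer with an LSTM layer and added a beam search decoder layer.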
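Beam search itself is independent of TensorFlow: at every step, each of the `beam_size` best partial hypotheses is extended by every token, and only the `beam_size` highest-scoring extensions survive. The following is a minimal plain-Python sketch of that idea, not the `tf.contrib.seq2seq` implementation. `toy_step` is a hypothetical stand-in for one decoder step that returns log-probabilities over a four-token vocabulary, and hypotheses are compared by raw summed log-probability (no length penalty).

```python
import math


def beam_search(step_fn, start_token, end_token, beam_size, max_len):
    """Return (best_sequence, log_prob) found by a simple beam search."""
    beams = [([start_token], 0.0)]  # each entry: (token ids, summed log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for token, logp in enumerate(step_fn(seq)):
                if logp == -math.inf:  # skip impossible tokens
                    continue
                candidates.append((seq + [token], score + logp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates:
            if seq[-1] == end_token:
                finished.append((seq, score))  # hypothesis is complete
            else:
                beams.append((seq, score))
            if len(beams) == beam_size:  # keep only the top beam_size open ones
                break
        if not beams:
            break
    finished.extend(beams)  # unfinished hypotheses still count at max_len
    return max(finished, key=lambda c: c[1])


def toy_step(seq):
    # Hypothetical model over the vocabulary [<s>, A, B, </s>] = [0, 1, 2, 3]
    last = seq[-1]
    if len(seq) >= 3:
        probs = [0.0, 0.0, 0.0, 1.0]
    elif last == 0:
        probs = [0.0, 0.5, 0.4, 0.1]
    elif last == 1:
        probs = [0.0, 0.34, 0.33, 0.33]
    else:
        probs = [0.0, 0.05, 0.05, 0.9]
    return [math.log(p) if p > 0 else -math.inf for p in probs]


best_seq, best_score = beam_search(toy_step, 0, 3, beam_size=2, max_len=5)
greedy_seq, greedy_score = beam_search(toy_step, 0, 3, beam_size=1, max_len=5)
print(best_seq, greedy_seq)  # [0, 2, 3] [0, 1, 1, 3]
```

With `beam_size=1` the search degenerates to greedy decoding and commits to token `1` after the first step, whereas a beam of 2 keeps the slightly less likely token `2` alive and ends with the higher total probability 0.4 × 0.9 = 0.36 (greedy: 0.5 × 0.34 = 0.17).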
```python
import tensorflow as tf  # TF 1.13 (tf.contrib is used below)

def call(self, x, hidden, hidden2, enc_output):
    # enc_output shape == (batch_size, max_length, hidden_size)
    # hidden, hidden2 shape == (batch_size, hidden_size)
    # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
    # we are doing this to perform addition to calculate the score
    hidden_with_time_axis = tf.expand_dims(hidden, 1)
    hidden_with_time_axis2 = tf.expand_dims(hidden2, 1)

    # score shape == (batch_size, max_length, 1)
    # we get 1 at the last axis because we are applying
    # tanh(FC(EO) + FC(H) + FC(H2)) to self.V
    score = self.V(
        tf.nn.tanh(self.W1(enc_output)
                   + self.W2(hidden_with_time_axis)
                   + self.W3(hidden_with_time_axis2)))

    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * enc_output
    context_vector = tf.reduce_sum(context_vector, axis=1)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # self.lstm must be built with return_state=True to return (output, h, c)
    output, h, c = self.lstm(x)
    logits = self.fc(output)

    # LSTMStateTuple's fields are (c, h) in that order, so name them explicitly
    beam = tf.contrib.seq2seq.BeamSearchDecoder(
        self.cell, self.embedding, self.start_tokens, self.end_token,
        tf.contrib.rnn.LSTMStateTuple(
            c=tf.contrib.seq2seq.tile_batch(c, multiplier=self.beam_size),
            h=tf.contrib.seq2seq.tile_batch(h, multiplier=self.beam_size)),
        self.beam_size)
    output, beamOp, _ = tf.contrib.seq2seq.dynamic_decode(
        beam, output_time_major=True, maximum_iterations=MAX_SEQUENCE_LENGTH)

    # keep only the top beam (index 0)
    predicted_ids = tf.transpose(tf.cast(output.predicted_ids[:, :, 0], tf.float32))
    beamOp_h = beamOp.cell_state.h[:, 0]
    beamOp_c = beamOp.cell_state.c[:, 0]
    return predicted_ids, logits, beamOp_h, beamOp_c, attention_weights
```
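The shape comments in `call()` can be checked without TensorFlow. The NumPy sketch below assumes random, bias-free weight matrices in place of the `self.W1`, `self.W2`, `self.W3` and `self.V` Dense layers (real Dense layers also carry a bias), and uses `np.repeat` to mimic `tf.contrib.seq2seq.tile_batch`. It also makes visible that the tiled LSTM state handed to `BeamSearchDecoder` is `beam_size` times larger along the batch axis.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, max_length, hidden_size, units, beam_size = 2, 5, 8, 10, 3

# Hypothetical random weights standing in for self.W1, self.W2, self.W3, self.V
W1 = rng.normal(size=(hidden_size, units))
W2 = rng.normal(size=(hidden_size, units))
W3 = rng.normal(size=(hidden_size, units))
V = rng.normal(size=(units, 1))

enc_output = rng.normal(size=(batch_size, max_length, hidden_size))
hidden = rng.normal(size=(batch_size, hidden_size))
hidden2 = rng.normal(size=(batch_size, hidden_size))

# (batch_size, 1, hidden_size): the time axis lets these broadcast against enc_output
hidden_with_time_axis = hidden[:, np.newaxis, :]
hidden_with_time_axis2 = hidden2[:, np.newaxis, :]

# (batch_size, max_length, units) -> tanh -> @ V -> (batch_size, max_length, 1)
score = np.tanh(enc_output @ W1
                + hidden_with_time_axis @ W2
                + hidden_with_time_axis2 @ W3) @ V

# softmax over the max_length axis, like tf.nn.softmax(score, axis=1)
exp_score = np.exp(score - score.max(axis=1, keepdims=True))
attention_weights = exp_score / exp_score.sum(axis=1, keepdims=True)

# weighted sum over time -> (batch_size, hidden_size)
context_vector = (attention_weights * enc_output).sum(axis=1)

# tile_batch repeats each batch entry beam_size times in a row:
# [h0, h1] -> [h0, h0, h0, h1, h1, h1], shape (batch_size * beam_size, hidden_size)
h = rng.normal(size=(batch_size, hidden_size))
tiled_h = np.repeat(h, beam_size, axis=0)

print(score.shape, attention_weights.shape, context_vector.shape, tiled_h.shape)
# (2, 5, 1) (2, 5, 1) (2, 8) (6, 8)
```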
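Now, when I start training the model, I run out of memory. I think I am making a fatal mistake here, but I can't figure out what it is. Can someone help me?
I tried training with TensorFlow GPU both on a local Windows machine with a GTX 1060 and on an AWS Ubuntu machine with a V100 GPU. Both environments use TensorFlow 1.13.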