I'm working on a seq2seq model that uses the tf.contrib.seq2seq module, and I'd like to build it so that it works under both graph mode and eager execution. As far as I know, tf.contrib.seq2seq does not yet work with eager execution (please correct me if I'm wrong), so I tried wrapping the relevant parts in tfe.defun(). However, I get the following error:
```
RuntimeError: tf.device does not support functions when eager execution is enabled.
```
on the following line:
```python
final_outputs, final_state, final_sequence_lengths = \
    tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
```
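For reference, the pattern I'm attempting looks roughly like the self-contained sketch below (the toy shapes, the plain LSTMCell, and the random inputs are just placeholders, not my real model); in my actual code the equivalent dynamic_decode call is what raises the error:

```python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Toy dimensions and inputs, standing in for my real embedded decoder inputs.
batch_size, max_time, emb_dim, num_units = 2, 5, 8, 16
decoder_embedding = tf.constant(
    np.random.rand(batch_size, max_time, emb_dim), dtype=tf.float32)
sequence_length = tf.constant([5, 3], dtype=tf.int32)

cell = tf.nn.rnn_cell.LSTMCell(num_units)

def graph_decoder():
    # Graph-mode seq2seq decoding, wrapped in a defun below.
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, sequence_length)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell, helper, initial_state=cell.zero_state(batch_size, tf.float32))
    final_outputs, final_state, final_sequence_lengths = \
        tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
    return final_outputs.rnn_output

graph_func = tfe.defun(graph_decoder)
outputs = graph_func()  # in my real code, the equivalent call fails with the RuntimeError above
```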
These are the two places where I use tfe.defun():
```python
def train_decoder_model(self, function_encoder_outputs, decoder_inputs, target_sequence_length):
    with tf.name_scope('embedding'):
        # The embedding layer expects integer indices instead of one-hot encodings.
        decoder_inputs_ints = tf.argmax(decoder_inputs, axis=-1)
        # Perform the embedding on the decoder input.
        decoder_embedding = tf.nn.embedding_lookup(self._emb_matrix, decoder_inputs_ints)

    with tf.name_scope('decoder'):
        target_sequence_length = tf.cast(target_sequence_length, tf.int32)

        def graph_decoder():
            helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding, target_sequence_length)
            decoder = tf.contrib.seq2seq.BasicDecoder(self.decoder_cell, helper,
                                                      function_encoder_outputs, self.decoder_dense)
            final_outputs, final_state, final_sequence_lengths = \
                tf.contrib.seq2seq.dynamic_decode(decoder, impute_finished=True)
            return final_outputs

        graph_func = tfe.defun(graph_decoder)
        final_outputs = graph_func()

    return final_outputs
```
```python
def loss(self, labels, logits, doc_length):
    def graph_loss():
        masks = tf.sequence_mask(doc_length, tf.reduce_max(doc_length), dtype=tf.float32, name='masks')
        return tf.contrib.seq2seq.sequence_loss(logits, tf.argmax(labels, -1), masks)

    graph_func = tfe.defun(graph_loss)
    return graph_func()
```
Any clues on how to solve this?