I am trying to remove teacher forcing from my decoder (built with a bidirectional dynamic RNN) so that it can run inference, but I cannot get it to work. I have tried several approaches — is there a way to do this without changing the architecture too much? Please help me with the code below.
class Architecture:
    """Bidirectional-LSTM encoder/decoder (seq2seq) graph built with TF 1.x.

    NOTE(review): reconstructed from an unindented paste. The original class
    header was machine-translated ("类架构:"); the intended name is assumed
    to be ``Architecture`` — confirm against the calling code.
    """

    def __init__(self):
        # `vocab` must be defined at module level before instantiation —
        # TODO confirm with the surrounding script.
        global vocab
        self.num_units = 200   # LSTM hidden size per direction
        self.embed_size = 100  # token embedding dimension
        self.batch_size = 32   # declared but not used in graph construction
        self.vocab = vocab
        # Separate trainable embedding matrices for encoder and decoder
        # vocabularies, initialised uniformly in [-1, 1).
        self.encoder_embedding = tf.Variable(
            tf.random_uniform((len(vocab), self.embed_size), -1, 1),
            trainable=True)
        self.decoder_embedding = tf.Variable(
            tf.random_uniform((len(vocab), self.embed_size), -1, 1),
            trainable=True)
        # Forward/backward cells for the bidirectional encoder and decoder.
        self.forward_encoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        self.backward_encoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        self.forward_decoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        self.backward_decoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        # Cells reserved for the (not yet implemented) summary decoder.
        self.summary_fw_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        self.summary_bw_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
        self.build_model()

    def encoder(self, input_):
        """Run the bidirectional encoder over embedded inputs.

        Args:
            input_: embedded input batch fed to ``bidirectional_dynamic_rnn``
                (presumably shaped (batch, time, embed_size) — TODO confirm).

        Returns:
            ``((output_fw, output_bw), (output_state_fw, output_state_bw))``
            exactly as produced by ``tf.nn.bidirectional_dynamic_rnn``.
        """
        with tf.variable_scope("Encoder"):
            (output_fw, output_bw), (output_state_fw, output_state_bw) = \
                tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=self.forward_encoder_lstm_cell,
                    cell_bw=self.backward_encoder_lstm_cell,
                    inputs=input_,
                    dtype=tf.float32)
        return (output_fw, output_bw), (output_state_fw, output_state_bw)

    def decoder(self, input_, output_state_fw, output_state_bw):
        """Run the (teacher-forced) bidirectional decoder.

        Args:
            input_: embedded ground-truth target sequence.
            output_state_fw: final forward encoder state, used as the
                forward decoder's initial state.
            output_state_bw: final backward encoder state, used as the
                backward decoder's initial state.

        Returns:
            ``((dec_output_fw, dec_output_bw),
            (dec_output_state_fw, dec_output_state_bw))``.

        NOTE(review): feeding the embedded ground-truth targets as ``inputs``
        IS teacher forcing, and ``bidirectional_dynamic_rnn`` has no way to
        feed its own predictions back step by step. To decode without teacher
        forcing while keeping most of this architecture, replace this call
        (forward direction) with ``tf.contrib.seq2seq.dynamic_decode`` over a
        ``BasicDecoder`` that uses ``GreedyEmbeddingHelper`` at inference
        time (``TrainingHelper`` for training), or drive the cell manually
        with ``tf.nn.raw_rnn``. A *backward* decoder direction is only
        meaningful under teacher forcing, because future tokens are unknown
        during inference.
        """
        with tf.variable_scope("Decoder"):
            (dec_output_fw, dec_output_bw), (dec_output_state_fw, dec_output_state_bw) = \
                tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=self.forward_decoder_lstm_cell,
                    cell_bw=self.backward_decoder_lstm_cell,
                    inputs=input_,
                    dtype=tf.float32,
                    initial_state_bw=output_state_bw,
                    initial_state_fw=output_state_fw)
        return (dec_output_fw, dec_output_bw), (dec_output_state_fw, dec_output_state_bw)

    def summary_decoder(self):
        """Placeholder for the summary decoder (not implemented).

        The original source contained the bare line ``def summary_decoder()``
        — a syntax error (missing ``self``, ``:`` and body). Made into an
        explicit stub so the module at least imports.
        """
        raise NotImplementedError("summary_decoder is not implemented yet")

    def build_model(self):
        """Assemble the graph: placeholders, encoder, decoder, projection."""
        # Encoder side: token ids -> embeddings -> bidirectional encoding.
        self.input_data = tf.placeholder(tf.int64, [None, None], name="input_data")
        input_embed = tf.nn.embedding_lookup(self.encoder_embedding, self.input_data)
        (output_fw, output_bw), (self.output_state_fw, self.output_state_bw) = \
            self.encoder(input_embed)
        # Decoder side: ground-truth output tokens are embedded and fed to
        # the decoder (teacher forcing); the encoder's final states seed the
        # decoder's initial states.
        self.out_data = tf.placeholder(tf.int64, [None, None], name="out_data")
        self.target_embed = tf.nn.embedding_lookup(self.decoder_embedding, self.out_data)
        (self.dec_output_fw, self.dec_output_bw), (self.dec_output_state_fw, self.dec_output_state_bw) = \
            self.decoder(self.target_embed, self.output_state_fw, self.output_state_bw)
        # Concatenate forward/backward decoder outputs along the feature axis.
        self.decoder_final_output = tf.concat([self.dec_output_fw, self.dec_output_bw], axis=2)
        # Backward-compat alias for the original misspelled attribute name.
        self.decoder_findal_output = self.decoder_final_output
        # Declared target placeholder; no loss/optimizer is built here —
        # presumably defined elsewhere or still to be written. TODO confirm.
        self.targets = tf.placeholder(tf.int64, [None, None], name="targets")
        # Per-timestep linear projection to vocabulary-sized logits.
        self.lg = tf.contrib.layers.fully_connected(
            self.decoder_final_output,
            num_outputs=len(self.vocab),
            activation_fn=None)
        # tf.arg_max is deprecated; tf.argmax with `axis` is the supported
        # API. NOTE(review): despite the name, this attribute holds predicted
        # token ids (argmax over the vocab axis), not logits — the actual
        # logits are `self.lg`. Name kept for caller compatibility.
        self.logits = tf.argmax(self.lg, axis=2)