Inference with a TensorFlow dynamic_rnn decoder

Time: 2019-07-16 04:39:58

Tags: tensorflow recurrent-neural-network

I am trying to remove teacher forcing from my decoder's dynamic_rnn so I can run inference, but I have not been able to. I have tried everything I could think of; is there a way to do this without changing the architecture too much? Please help me with code.

Class architecture:

def __init__(self):
    global vocab
    self.num_units = 200
    self.embed_size = 100
    self.batch_size = 32
    self.vocab = vocab
    self.encoder_embedding = tf.Variable(tf.random_uniform((len(vocab), self.embed_size), -1, 1),trainable=True)
    self.decoder_embedding = tf.Variable(tf.random_uniform((len(vocab), self.embed_size), -1, 1),trainable=True)
    self.forward_encoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.backward_encoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.forward_decoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.backward_decoder_lstm_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.summary_fw_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.summary_bw_cell = tf.nn.rnn_cell.LSTMCell(self.num_units)
    self.build_model()

def encoder(self,input_):
    with tf.variable_scope("Encoder"): 
        (output_fw, output_bw), (output_state_fw, output_state_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.forward_encoder_lstm_cell, 
                                                                         cell_bw=self.backward_encoder_lstm_cell, 
                                                                         inputs=input_,
                                                                         dtype=tf.float32)
        return (output_fw, output_bw), (output_state_fw, output_state_bw)

def decoder(self,input_,output_state_fw,output_state_bw):
    with tf.variable_scope("Decoder"):
        (dec_output_fw, dec_output_bw), (dec_output_state_fw, dec_output_state_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.forward_decoder_lstm_cell, 
                                                                         cell_bw=self.backward_decoder_lstm_cell, 
                                                                         inputs=input_,
                                                                         dtype=tf.float32,
                                                                         initial_state_bw=output_state_bw,
                                                                         initial_state_fw=output_state_fw)
        return (dec_output_fw, dec_output_bw), (dec_output_state_fw, dec_output_state_bw)

def summary_decoder(self):
    pass  # not implemented yet

def build_model(self):
    self.input_data = tf.placeholder(tf.int64, [None, None], name="input_data")
    input_embed = tf.nn.embedding_lookup(self.encoder_embedding, self.input_data)
    (output_fw, output_bw), (self.output_state_fw, self.output_state_bw) = self.encoder(input_embed) 

    self.out_data = tf.placeholder(tf.int64, [None, None], name="out_data")
    # Teacher forcing: the full gold target sequence is embedded and fed to the decoder.
    self.target_embed = tf.nn.embedding_lookup(self.decoder_embedding, self.out_data)
    (self.dec_output_fw, self.dec_output_bw), (self.dec_output_state_fw, self.dec_output_state_bw) = self.decoder(self.target_embed,self.output_state_fw,self.output_state_bw)


    self.decoder_final_output = tf.concat([self.dec_output_fw, self.dec_output_bw], axis=2)
    self.targets = tf.placeholder(tf.int64, [None, None], name="targets")
    # Per-step vocabulary logits.
    self.lg = tf.contrib.layers.fully_connected(self.decoder_final_output, num_outputs=len(self.vocab), activation_fn=None)
    # Greedy predictions: argmax over the vocabulary dimension.
    self.logits = tf.argmax(self.lg, axis=2)

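For reference, below is a minimal sketch of the usual TF 1.x way to switch off teacher forcing at inference time with tf.contrib.seq2seq: train with TrainingHelper (which feeds the gold targets, i.e. teacher forcing) and decode with GreedyEmbeddingHelper (which feeds each step's argmax prediction back in as the next input). This is an assumption-laden sketch, not the code above: it assumes a unidirectional decoder cell (an autoregressive decoding loop cannot be bidirectional, because the backward direction needs the whole sequence in advance), and the sizes, the go_id / eos_id token ids, and the zero initial state are stand-ins.

import tensorflow as tf

# All sizes and token ids below are stand-ins, not values from the question.
vocab_size = 10000
embed_size = 100
num_units = 200
batch_size = 32
go_id, eos_id = 1, 2          # hypothetical GO / EOS token ids
max_decode_len = 50

decoder_embedding = tf.Variable(
    tf.random_uniform((vocab_size, embed_size), -1, 1), name="decoder_embedding")
decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units)
projection = tf.layers.Dense(vocab_size, name="output_projection")

# Stand-in for the encoder's final state (e.g. output_state_fw from the encoder).
encoder_state = decoder_cell.zero_state(batch_size, tf.float32)

# Training graph: teacher forcing, the gold target embeddings are fed at every step.
out_data = tf.placeholder(tf.int64, [batch_size, None], name="out_data")
target_lengths = tf.placeholder(tf.int32, [batch_size], name="target_lengths")
target_embed = tf.nn.embedding_lookup(decoder_embedding, out_data)

with tf.variable_scope("decoder"):
    train_helper = tf.contrib.seq2seq.TrainingHelper(target_embed, target_lengths)
    train_decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder_cell, train_helper, encoder_state, output_layer=projection)
    train_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(train_decoder)
    train_logits = train_outputs.rnn_output      # feed this to the training loss

# Inference graph: no teacher forcing, each step consumes the previous prediction.
with tf.variable_scope("decoder", reuse=True):
    infer_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        decoder_embedding,
        start_tokens=tf.fill([batch_size], go_id),
        end_token=eos_id)
    infer_decoder = tf.contrib.seq2seq.BasicDecoder(
        decoder_cell, infer_helper, encoder_state, output_layer=projection)
    infer_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
        infer_decoder, maximum_iterations=max_decode_len)
    predicted_ids = infer_outputs.sample_id      # [batch, time] greedy token ids

In the graph above, output_state_fw from the bidirectional encoder (or a dense projection of the concatenated fw/bw states) would take the place of the zero state passed as initial_state; the decoder itself has to stay unidirectional for this pattern to work.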
0 Answers:

No answers