在张量流中对非文本任务实施波束搜索

时间:2018-08-09 04:00:01

标签: tensorflow lstm rnn beam-search

说我有一个用dynamic_rnn训练过的rnn细胞:

rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units)

# input_tensor: [batch_size, max_time, feature_size], each element of the last dimension is an sample
# rnn outputs: [batch_size, max_time, num_units]
# final_state: [batch_size, num_units]
rnn_outputs, final_state = tf.nn.dynamic_rnn(
        rnn_cell, input_tensor, dtype=tf.float32, 
        sequence_length=seq_length, time_major=False)

logits = tf.layers.dense(
        final_state.h, 2, 
        activation=None, 
        use_bias=True, 
        name='logits') # batch_size * 2

loss = tf.reduce_mean(
    tf.losses.sparse_softmax_cross_entropy(
        logits=logits, # batch_size * 2
        labels=labels))

# consider prob of class '1' as a score
score = tf.nn.softmax(logits)[:,1]

现在,给定一个input_tensor进行预测,我想进行波束搜索以找到输入样本的最佳排名,以使排名样本的总得分最大。

0 个答案:

没有答案