澄清TensorFlow的dynamic_decode sample_ids

时间:2017-06-21 21:54:14

标签: tensorflow

tensorflow.contrib.seq2seq的{​​{1}}会返回三个值,第一个值是带有命名字段dynamic_decode'rnn_output'的2元组。我试图了解sample_id是什么,但我找不到任何示例或文档,而TensorFlow开发者峰会上的example并未添加太多信息。有人可以解释一下吗?

2 个答案:

答案 0 :(得分:0)

sample_id是rnn输出的argmax

答案 1 :(得分:0)

rnn_output=[batch_size, max length of a sentence, probability of each word in a vocabulary]
sample_id = [batch_size, max length of a sentence]

例如:

batch_size is 99 
max length of a sentence is 15
Vocabulary size is 233

rnn_output = [99,15,233]
sample_id = [99,15]

如上所述,sample_id第二维包含rnn_output的第三维的argmax值。

在一种简单的语言中,sample_id的第二维将是rnn_output的第三个维度dimension->max value->index