在Tensorflow RNN样本中从id获取单词

时间:2016-05-12 06:47:42

标签: python tensorflow

我正在尝试修改Tensorflow的RNN示例。

https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html

在ptb_word_lm.py我想他们正在输入单词索引的int数组(m.input_data:x)。

def run_epoch(session, m, data, eval_op, verbose=False):
  """Runs the model on the given data."""
  epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps
  start_time = time.time()
  costs = 0.0
  iters = 0
  state = m.initial_state.eval()
  for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size,
                                                    m.num_steps)):
    cost, state, _ = session.run([m.cost, m.final_state, eval_op],
                                 {m.input_data: x,
                                  m.targets: y,
                                  m.initial_state: state})

我想看到实际的单词而不是ID,我怎么能看到它们?

1 个答案:

答案 0 :(得分:1)

首先需要保留词汇表(从word到id的索引)。

在main的顶部,保留reader.ptb_raw_data()的第4个返回值,如下所示。

raw_data = reader.ptb_raw_data(FLAGS.data_path)
train_data, valid_data, test_data, vocabulary = raw_data

然后将词汇表传递给run_epoch()。

test_perplexity = run_epoch(session, mtest, test_data, tf.no_op(), vocabulary)

在run_epoch()内部,当你想在x的第一步中将id转换为单词时,

def run_epoch(session, m, data, eval_op, vocabulary, verbose=False):

...
for step, (x, y) in enumerate(...

message ="x: "
for i in range(0, m.num_steps):
    key = vocabulary.keys()[vocabulary.values().index(x[0][i])]
    message += key + " "

print(message)

希望它有所帮助。