Tensorflow:对tf.estimator.inputs.numpy_input_fn函数进行故障排除

时间:2017-10-16 04:27:15

标签: python tensorflow

我正在运行text classification

中的一些教程代码

我可以运行脚本并且它可以工作但是当我尝试逐行运行它以试图理解每个步骤正在做什么时,我对此步骤感到有些困惑:

test_input_fn = tf.estimator.inputs.numpy_input_fn(
  x={WORDS_FEATURE: x_test},
  y=y_test,
  num_epochs=1,
  shuffle=False)
classifier.train(input_fn=train_input_fn, steps=100)

我从概念上知道train_input_fn正在向训练函数提供数据,但我如何手动调用此fn来检查其中的内容?

我已经跟踪了代码并发现train_input_fn函数将数据提供给以下2个变量:

features
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>}

labels
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32>

当我尝试通过执行sess.run(功能)来评估功能变量时,我的终端似乎卡住并停止响应。

检查这些变量内容的正确方法是什么?

谢谢!

1 个答案:

答案 0 :(得分:2)

基于numpy_input_fn documentation和行为(挂起),我想底层实现依赖于队列运行器。当队列跑步者没有开始时发生挂起。尝试根据this guide

将会话运行脚本修改为以下内容
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for step in xrange(1000000):
            if coord.should_stop():
                break
            features_data = sess.run(features)
            print(features_data)

    except Exception, e:
        # Report exceptions to the coordinator.
        coord.request_stop(e)
    finally:
        # Terminate as usual. It is safe to call `coord.request_stop()` twice.
        coord.request_stop()
        coord.join(threads)

或者,我建议您查看tf.data.Dataset界面(张量流1.3或之前可能tf.contrib.data.Dataset)。您可以在不使用Dataset.from_tensor_slices的队列的情况下获得类似的输入/标签张量。创建稍微复杂一些,但界面更灵活,实现不使用队列运行器,这意味着会话运行更简单。

import tensorflow as tf
import numpy as np

x_data = np.random.random((100000, 2))
y_data = np.random.random((100000,))

batch_size = 2
buff = 100


def input_fn():
    # possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
    dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
    dataset = dataset.repeat().shuffle(buff).batch(batch_size)
    x, y = dataset.make_one_shot_iterator().get_next()
    return x, y


x, y = input_fn()
with tf.Session() as sess:
    print(sess.run([x, y]))