TensorFlow: How do I handle graphs correctly?

Date: 2016-10-12 04:38:21

Tags: python machine-learning tensorflow deep-learning

I'm trying to set up a network with tf.contrib.learn:

# Imports, definitions of directory paths...

def main(unused_argv):
    hparams = seg_hparams.create_hparams()

    input_fn_train = seg_inputs.create_input_fn(
            hparams=hparams,
            mode=tf.contrib.learn.ModeKeys.TRAIN,
            input_dir=TRAIN_DATA)

    model_fn = seg_model.create_model_fn(
            hparams,
            model_impl=forward_backward_model)

    estimator = tf.contrib.learn.Estimator(
            model_fn=model_fn,
            model_dir=MODEL_DIR)

    estimator.fit(input_fn=input_fn_train, steps=None)

if __name__ == "__main__":
    tf.app.run()

For the input, I use a custom queue similar to the one described in the tutorial at https://indico.io/blog/tensorflow-data-input-part2-extensions/.
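For comparison, here is a minimal sketch of an input_fn factory in the style of seg_inputs.create_input_fn in which the custom queue is built inside the returned function rather than in the factory itself. It assumes that the Estimator calls input_fn() while its own training graph is the default graph. The factory's parameters and the random stand-in data are hypothetical (a real version would read from TRAIN_DATA as in the indico.io tutorial). It uses the tf.compat.v1 spelling so it also builds on later releases; on TF 0.11 the same classes are tf.RandomShuffleQueue, tf.train.QueueRunner and tf.train.add_queue_runner.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # 1.x graph API; on TF 0.11 use the tf.* names directly


def create_input_fn(batch_size=6, image_size=256, capacity=64):
    """Hypothetical factory: it only captures Python values, no TF ops."""

    def input_fn():
        # Every op below is added to the graph that is default *when
        # input_fn() is called*, i.e. the graph the caller is building.
        shape = [image_size, image_size, 1]
        # Stand-in data (assumption): random images and thresholded masks
        # instead of images read from TRAIN_DATA.
        images_src = tf1.random_uniform([batch_size] + shape)
        masks_src = tf1.cast(images_src > 0.5, tf.float32)

        queue = tf1.RandomShuffleQueue(
            capacity=capacity,
            min_after_dequeue=batch_size,
            dtypes=[tf.float32, tf.float32],
            shapes=[shape, shape])
        enqueue_op = queue.enqueue_many([images_src, masks_src])
        # Register a runner so queue-runner threads (if the training loop
        # starts them) can keep the queue filled.
        tf1.train.add_queue_runner(tf1.train.QueueRunner(queue, [enqueue_op]))

        images, masks = queue.dequeue_many(batch_size)
        return images, masks

    return input_fn
```

Because the factory creates no ops, it does not matter where create_input_fn is called; what matters is that the queue is constructed inside input_fn, so it lands in whichever graph is default at that moment.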

When I run the program, I get the following error:

Traceback (most recent call last):
  File "train.py", line 54, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "train.py", line 49, in main
    estimator.fit(input_fn=input_fn_train, steps=None)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 332, in fit
    max_steps=max_steps)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 650, in _train_model
    train_op, loss_op = self._get_train_ops(features, targets)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 951, in _get_train_ops
    _, loss, train_op = self._call_model_fn(features, targets, ModeKeys.TRAIN)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 934, in _call_model_fn
        ....
ValueError: Tensor("random_shuffle_queue_DequeueMany:0", shape=(6, 256, 256, 1), dtype=float32, device=/device:CPU:0) must be from the same graph as Tensor("seg_net/conv1/weights:0", shape=(3, 3, 1, 32), dtype=float32_ref).
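The same ValueError can be reproduced without the Estimator. This is a minimal sketch, assuming (my reading of the message) that the dequeued batch and the conv1 weights were created while two different graphs were the default; tf.zeros merely stands in for the queue output and the weights variable, and the function name is hypothetical.

```python
import tensorflow as tf


def reproduce_graph_mismatch():
    # Graph A: stands in for the graph in which the input queue was built.
    input_graph = tf.Graph()
    with input_graph.as_default():
        images = tf.zeros([6, 256, 256, 1])  # like random_shuffle_queue_DequeueMany:0

    # Graph B: stands in for the graph in which model_fn builds the network.
    model_graph = tf.Graph()
    with model_graph.as_default():
        weights = tf.zeros([3, 3, 1, 32])  # like seg_net/conv1/weights:0
        try:
            # An op whose inputs come from two different graphs is rejected.
            tf.nn.conv2d(images, weights, [1, 1, 1, 1], "SAME")
        except ValueError as err:
            return str(err)
    return None


print(reproduce_graph_mismatch())
```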

Basically, I want to know how to make sure the same graph is used across the different functions. I was thinking of using something like
with tf.Graph().as_default():

in input_fn_train and

with tf.get_default_graph():

in model_fn.

So far, however, I haven't been able to solve the problem.
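A quick check of why opening a new graph inside input_fn_train is likely to make things worse rather than better: assuming the Estimator already makes its own graph the default before it calls input_fn (the traceback shows model_fn being called from Estimator._train_model), a `with tf.Graph().as_default():` block inside input_fn builds the features in a second, unrelated graph. This hedged sketch mimics that outer graph with a hypothetical helper and compares both variants; all function names are illustrative.

```python
import tensorflow as tf


def input_fn_with_new_graph():
    # Variant from the question: the features live in a brand-new graph.
    with tf.Graph().as_default():
        return tf.zeros([6, 256, 256, 1])


def input_fn_using_default_graph():
    # No graph handling at all: ops go into whichever graph is default.
    return tf.zeros([6, 256, 256, 1])


def builds_into_callers_graph(input_fn):
    """Hypothetical stand-in for an Estimator that opens its own graph
    and then calls input_fn (and, later, model_fn) inside it."""
    estimator_graph = tf.Graph()
    with estimator_graph.as_default():
        features = input_fn()
        return features.graph is estimator_graph


print(builds_into_callers_graph(input_fn_with_new_graph))       # False
print(builds_into_callers_graph(input_fn_using_default_graph))  # True
```

Under that assumption, neither input_fn nor model_fn needs any graph handling of its own; what matters is that none of the ops they use, including the custom queue, is created before the Estimator calls them.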

0 answers