如何从main()返回变量?

时间:2018-06-01 21:59:08

标签: python tensorflow

TensorFlow是全新的,我正在尝试修改他们给出的一些例子。例如:

https://github.com/tensorflow/tensorflow/blob/4806cb0646bd21f713722bd97c0d0262c575f7e0/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py

"""Simple MNIST classifier example with JIT XLA and timelines.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.client import timeline

FLAGS = None


def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir)

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  w = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, w) + b
  ....
  ....
  ....
  ....

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), y_)
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy,
                 feed_dict={x: mnist.test.images,
                            y_: mnist.test.labels}))
  sess.close()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory for storing input data')
  parser.add_argument(
      '--xla', type=bool, default=True, help='Turn xla via JIT on')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

这将打印" 0.9202"在命令行上。如何返回值以便我可以在其他函数中使用它?

 val = tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
 print(val)

我明白了:

[pylint] E1111:Assigning to function call which doesn't return

此外,在执行该功能后无法执行任何操作。如果我尝试打印("此字符串"),则程序会在打印之前退出。

编辑: 答案到目前为止给出了同样的错误:

[pylint] E1111:Assigning to function call which doesn't return

我查看了TF的很多示例,但找不到如何返回值而不是将其打印到控制台的示例。

2 个答案:

答案 0 :(得分:0)

您首先gess是正确的,将val =添加到您的函数调用中。但是你也应该从函数中返回一个值。

替换:

print(sess.run(accuracy,
               feed_dict={x: mnist.test.images,
                        y_: mnist.test.labels}))

由:

return sess.run(accuracy,
             feed_dict={x: mnist.test.images,
                        y_: mnist.test.labels})

答案 1 :(得分:0)

在这里查看How does tf.app.run() work?

本质上tf.app.run是一个使用一些参数调用main的包装器。您可以更改打印状态以将结果分配给变量并将其返回或在主要内部调用您自己的功能以写入某些位置

...
result = (sess.run(accuracy,
             feed_dict={x: mnist.test.images,
                        y_: ...
sess.close()
return result

也许本教程的完整代码使这一点更加清晰:

https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/examples/tutorials/layers/cnn_mnist.py