使用tf.app.run()从类中调用main函数

时间:2018-03-31 22:19:30

标签: tensorflow machine-learning

问题很简单,我正在尝试使用tf.app.run()来调用类中的main函数。但是,以下代码给了我一个错误。任何帮助表示赞赏。

将tensorflow导入为tf     import sys

# Where to find data
tf.app.flags.DEFINE_string('f1', '', 'feature 1')
tf.app.flags.DEFINE_string('f2', '', 'feature 2')

FLAGS = tf.app.flags.FLAGS

class Test(object):
    def __init__(self):
        pass
    def main(self, args):
        print(FLAGS.__flag.iteritems())

def main(args):
    test = Test()
    test.main(args)

if __name__ == '__main__':
    tf.app.run(main)

这是错误:

Traceback (most recent call last):
  File "test.py", line 21, in <module>
    tf.app.run(main)
  File "/Users/yaserkeneshloo/anaconda/envs/env27/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "test.py", line 18, in main
    test.main(args)
  File "test.py", line 14, in main
    print(FLAGS.__flag.iteritems())
  File "/Users/yaserkeneshloo/anaconda/envs/env27/lib/python2.7/site-packages/tensorflow/python/platform/flags.py", line 85, in __getattr__
    return wrapped.__getattr__(name)
  File "/Users/yaserkeneshloo/anaconda/envs/env27/lib/python2.7/site-packages/absl/flags/_flagvalues.py", line 470, in __getattr__
    raise AttributeError(name)
AttributeError: _Test__flag

3 个答案:

答案 0 :(得分:2)

问题在于,当您尝试从类内部以双下划线开头访问FLAGS的属性时,它会在属性前面加上类名称。因此,它不是FLAGS.__flags而是尝试查找不存在的FLAGS._Test__flags

设置或获取值都是如此。因此,如果您从课程外部设置值,则必须在其前面添加_Test(因为您将课程命名为Test。)如果您在课程中设置了标记,则不必将其设置为__flags。需要前缀,因为它在分配值时也会自动执行前缀。

所以基本上你对你的代码没有任何问题,因为你自己的标志不是以双下划线开头,除非你不能使用内部# Where to find data tf.app.flags.DEFINE_string('_Test__f1', 'feature 1', 'feature 1') tf.app.flags.DEFINE_string('__f2', 'feature 2', 'feature 2') FLAGS = tf.app.flags.FLAGS print( FLAGS.__f2 ) # prints "feature 2" class Test(object): def __init__(self): pass def main(self, args): print( FLAGS.__f1 ) # prints "feature 1" FLAGS.__f1 = 'foobar' # assignment works the same way print( FLAGS.__f1 ) # prints "foobar" print( FLAGS.__f2 ) # AttributeError: _Test__f2 def main(args): test = Test() test.main(args) if __name__ == '__main__': tf.app.run(main) 属性来打印所有标志。您可以单独访问它们。

有关详细示例,请参阅下面的代码。 (另外,DEFINE行中的默认值是第二个参数,而不是第三个参数。)

services:
    _defaults:
        autowire: true      
        autoconfigure: true 
        public: false      
        bind:
            $walletCreds: '%app.wallet_creds%'

答案 1 :(得分:0)

问题是FLAGS没有名为__flag的属性。如果要打印与f1标记对应的字符串,请调用print(FLAGS.f1)

答案 2 :(得分:0)

这样可以解决问题:

{{1}}