TensorFlow flags not recognized

Asked: 2017-02-07 10:22:36

Tags: python tensorflow pyspark

I have TensorFlow code that runs on PySpark. The code:

import tensorflow as tf

tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this many steps (default: 100)")
tf.flags.DEFINE_integer("window_size", 3, "n-gram")
tf.flags.DEFINE_integer("sequence_length", 204, "max tokens b/w entities")
tf.flags.DEFINE_integer("K", 4, "K-fold cross validation")
tf.flags.DEFINE_float("early_threshold", 0.5, "Threshold to stop the training")

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()

I have an RDD on which I call a function get_input(), but I cannot print the FLAGS values inside it.

def get_input(row):
    FLAGS = tf.flags.FLAGS
    print(FLAGS.__flags)

However, I get the following error.

org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 98, in main
    command = pickleSer._read_with_length(infile)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 164, in _read_with_length
    return self.loads(obj)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 419, in loads
    return pickle.loads(obj, encoding=encoding)
  File "/home/sahil/anaconda2/envs/tensorflow3/lib/python3.4/site-packages/tensorflow/python/platform/flags.py", line 47, in __getattr__
    if not self.__dict__['__parsed']:
KeyError: '__parsed'
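
One plausible reading of this traceback: the error is raised inside pickle.loads on the worker, which suggests the FLAGS object was serialized along with the function and is being rebuilt there. When pickle rebuilds an object it creates the instance without calling __init__ and then looks up __setstate__; a class whose __getattr__ reads self.__dict__['__parsed'] fails at that point because __dict__ is still empty. The sketch below reproduces that mechanism with the standard library only. FakeFlagValues and simulate_unpickle are hypothetical names made up for illustration, not TensorFlow code.

```python
class FakeFlagValues(object):
    """Hypothetical stand-in with the same __getattr__ pattern as the traceback."""

    def __init__(self):
        # Written through __dict__, as the traceback line suggests.
        self.__dict__['__flags'] = {'window_size': 3}
        self.__dict__['__parsed'] = True

    def __getattr__(self, name):
        # Same lookup as the failing line: assumes __init__ has already run.
        if not self.__dict__['__parsed']:
            raise AttributeError(name)
        flags = self.__dict__['__flags']
        if name not in flags:
            raise AttributeError(name)
        return flags[name]


def simulate_unpickle(cls):
    # Step 1: pickle creates the instance without calling __init__.
    obj = cls.__new__(cls)
    # Step 2: pickle looks up __setstate__ before restoring __dict__.
    # That lookup falls through to __getattr__, where __dict__ is still empty.
    return getattr(obj, '__setstate__', None)


missing_key = None
try:
    simulate_unpickle(FakeFlagValues)
except KeyError as exc:
    missing_key = exc.args[0]

print(FakeFlagValues().window_size)  # normal access works after __init__
print(missing_key)                   # the key that was missing during rebuild
```

If this is what happens in the job, the practical consequence is that the flags object itself should not travel to the workers; only plain values should.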

In addition, I cannot create a broadcast variable from a TensorFlow object.

For example, if I write

ones = tf.ones([2,3])
ones = sc.broadcast(ones)

I get an error.
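
A broadcast value has to be serialized before it can reach the workers, and a tensor handle is likely not something that survives that step. A possible workaround, sketched here under that assumption, is to broadcast plain data and build the tensor on the worker. The example uses only the standard library: make_ones is a hypothetical helper, and the pickle round trip merely stands in for what sc.broadcast would do with the value.

```python
import pickle


def make_ones(rows, cols):
    # Plain nested list with the same shape as tf.ones([rows, cols]).
    return [[1.0] * cols for _ in range(rows)]


ones_data = make_ones(2, 3)

# Stand-in for the serialization a broadcast performs (assumption: plain
# Python lists are what gets shipped, not a TensorFlow object).
payload = pickle.dumps(ones_data)
restored = pickle.loads(payload)

# On a worker, the tensor could then be rebuilt from the plain data,
# for example with tf.constant(restored).
print(restored)
```

In the real job this would correspond to broadcasting ones_data with sc.broadcast and reading it back on the worker through the broadcast's value attribute.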

1 Answer:

Answer 0: (score: 0)

Use the latest TensorFlow (> 1.4) with Python 3 and call FLAGS(sys.argv), because FLAGS._parse_flags() is no longer supported.

import sys
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
unparsed = FLAGS(sys.argv)
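
Once the flags are parsed on the driver, one way to make their values available inside get_input() without touching the FLAGS object on the workers is to copy them into a plain dict and bind that dict to the function. This is a minimal sketch, not part of the original answer: flag_values is typed by hand here, the body of get_input is invented for illustration, and a Python list stands in for the RDD.

```python
from functools import partial

# Plain copy of the flag values from the question. On the driver this dict
# could be filled from the parsed FLAGS object instead of written by hand.
flag_values = {
    "evaluate_every": 100,
    "window_size": 3,
    "sequence_length": 204,
    "K": 4,
    "early_threshold": 0.5,
}


def get_input(row, config):
    # Only the plain dict is read here; tf.flags.FLAGS is never referenced.
    return row[:config["window_size"]]


# Bind the config once so the function takes a single row argument.
get_input_with_config = partial(get_input, config=flag_values)

# Stand-in for an RDD; with Spark this would be rdd.map(get_input_with_config).
rows = [["a", "b", "c", "d"], ["e", "f"]]
result = list(map(get_input_with_config, rows))
print(result)
```

Because the dict holds only ints and floats, it serializes without involving any TensorFlow object.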