I have TensorFlow code that runs on PySpark. The code:
tf.flags.DEFINE_integer("evaluate_every", 100, "Evaluate model on dev set after this man
y steps (default: 100)")
tf.flags.DEFINE_integer("window_size", 3, "n-gram")
tf.flags.DEFINE_integer("sequence_length", 204, "max tokens b/w entities")
tf.flags.DEFINE_integer("K", 4, "K-fold cross validation")
tf.flags.DEFINE_float("early_threshold", 0.5, "Threshold to stop the training")
FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
I have an RDD on which I am calling a function get_input(), but I am unable to print the FLAGS values.
def get_input(row):
    # Executed on the Spark workers, not on the driver.
    FLAGS = tf.flags.FLAGS
    print(FLAGS.__flags)
But I received the following error:
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/worker.py", line 98, in main
    command = pickleSer._read_with_length(infile)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 164, in _read_with_length
    return self.loads(obj)
  File "/usr/local/spark/python/lib/pyspark.zip/pyspark/serializers.py", line 419, in loads
    return pickle.loads(obj, encoding=encoding)
  File "/home/sahil/anaconda2/envs/tensorflow3/lib/python3.4/site-packages/tensorflow/python/platform/flags.py", line 47, in __getattr__
    if not self.__dict__['__parsed']:
KeyError: '__parsed'
In addition, I am unable to create a broadcast variable from a TensorFlow object. For example, if I write
ones = tf.ones([2,3])
ones = sc.broadcast(ones)
I get an error.
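(Presumably this broadcast fails because tf.ones([2,3]) is a graph tensor rather than a plain value, and Spark has to pickle whatever it broadcasts. A minimal sketch of the usual workaround, assuming `sc` is the active SparkContext, is to broadcast a NumPy array and rebuild the tensor on the worker:)

import numpy as np

# tf.ones([2, 3]) is a graph tensor and does not pickle; the underlying
# value does, so broadcast that instead.
ones_value = np.ones((2, 3), dtype=np.float32)
ones = sc.broadcast(ones_value)  # `sc` is assumed to be the active SparkContext
# Inside a worker function, tf.constant(ones.value) recreates the tensor if needed.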
Answer 0: (score: 0)
Get the latest TensorFlow (> 1.4) on Python 3 and parse the flags with FLAGS(sys.argv), because FLAGS._parse_flags() is no longer supported.
import sys
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
unparsed = FLAGS(sys.argv)
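A slightly fuller sketch of that pattern, reusing a couple of flag names from the question (the only new piece is the FLAGS(sys.argv) call replacing FLAGS._parse_flags()):

import sys
import tensorflow as tf

tf.app.flags.DEFINE_integer("evaluate_every", 100,
                            "Evaluate model on dev set after this many steps (default: 100)")
tf.app.flags.DEFINE_integer("window_size", 3, "n-gram")

FLAGS = tf.app.flags.FLAGS
# In TensorFlow > 1.4 the FLAGS object is parsed by calling it on argv;
# FLAGS._parse_flags() is no longer supported.
unparsed = FLAGS(sys.argv)

print(FLAGS.evaluate_every)  # 100 unless overridden on the command line
print(FLAGS.window_size)     # 3

If the values are then needed inside Spark workers, copying them into plain Python variables (or a dict) on the driver before building the RDD functions avoids shipping the FLAGS object itself, which appears to be what fails to unpickle in the traceback above.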