在Tensorflow中序列化/反序列化FLAGS

时间:2019-02-05 23:52:51

标签: tensorflow

我想序列化我的tensorflow配置标志,以便它们存储在文件中,以后我可以重新加载它们。 json模块反对序列化“ Flag”类型。我尝试使用标志nameFLAGS.__flags[name].value构建新字典,但是有些标志是嵌套字典。 似乎我正在尝试重新发明轮子。有具体的配置序列化/反序列化示例吗?

TypeError: Object of type 'Flag' is not JSON serializable

1 个答案:

答案 0 :(得分:1)

您可能想尝试这样的事情:

def flag_to_dict(FLAGS):
    if tf.__version__ == '1.5':
        flag_dict = FLAGS.flag_values_dict()
    else:
        flag_dict = FLAGS.__flags
    return flag_dict

但是,在一些最新的TensorFlow版本中,我仍然遇到相同的错误!在那种情况下,自定义序列化程序(连同上面的代码)解决了这个问题:

class TfAwareEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, tf.flags.Flag):
            return obj.value
        else:
            return super(TfAwareEncoder, self).default(obj)
# ...
json.dump(flag_dict, open_file, indent=4, sort_keys=True, cls=TfAwareEncoder)