如何在tensorflow中使用集合来保存我自己的对象

时间:2017-01-05 05:15:27

标签: tensorflow

我想使用tf.add_to_collection()来保存我自己的对象,以便以后轻松获取它们。 以下是代码段:

class Model(object):
    def __init__(self, scope, is_training=True):

将对象添加到集合:

for i in xrange(num_gpus):
    with tf.device("/gpu:%d"%i):
        with tf.name_scope("tower_%d"%i) as scope:
            m = Model.Model(scope)
            tf.add_to_collection("train_model", m)

从集合中获取对象:

models = tf.get_collection("train_model")

代码工作正常,但我收到警告:

WARNING:tensorflow:Error encountered when serializing train_model.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'Model' object has no attribute 'name

我应该做些什么来避免这种警告?

1 个答案:

答案 0 :(得分:0)

当您调用tf.train.Saver.save()时,(可能)会产生警告,它会尝试写出代表tf.Graph内容的“MetaGraph”,包括所有图表集合的内容

避免警告的最简单方法是在调用write_meta_graph=False时通过saver.save()。但是,这使您无需稍后导入MetaGraph。

如果要保存MetaGraph 以避免警告,则需要实现必要的挂钩(to_protofrom_proto)以序列化Model object作为tf.train.MetaGraphDef序列化格式的协议缓冲区。 MetaGraph tutorial解释了如何执行此操作,但基本思路如下:

  1. 定义描述ModelProto对象内容的协议缓冲区(Model)。

  2. 定义model_to_proto()函数,将Model序列化为ModelProto

    def model_to_proto(model):
        ret = ModelProto()
        # Set fields of `ret` from `model`.
        return ret
    
  3. 定义model_from_proto()函数,对ModelProto进行反序列化并返回Model

    def model_from_proto(model_proto):
        # Construct a `Model` from the fields of `model_proto`.
        return Model(...)
    
  4. 注册"train_model"集合的功能。这当前使用的是一个名为register_proto_function()

    的无证功能
    from tensorflow.python.framework import ops
    
    ops.register_proto_function("train_model",
                                proto_type=ModelProto,
                                to_proto=model_to_proto,
                                from_proto=model_from_proto)