我想使用tf.add_to_collection()来保存我自己的对象,以便以后轻松获取它们。 以下是代码段:
class Model(object):
def __init__(self, scope, is_training=True):
将对象添加到集合:
for i in xrange(num_gpus):
with tf.device("/gpu:%d"%i):
with tf.name_scope("tower_%d"%i) as scope:
m = Model.Model(scope)
tf.add_to_collection("train_model", m)
从集合中获取对象:
models = tf.get_collection("train_model")
代码工作正常,但我收到警告:
WARNING:tensorflow:Error encountered when serializing train_model.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'Model' object has no attribute 'name
我应该做些什么来避免这种警告?
答案 0 :(得分:0)
当您调用tf.train.Saver.save()
时,(可能)会产生警告,它会尝试写出代表tf.Graph
内容的“MetaGraph”,包括所有图表集合的内容
避免警告的最简单方法是在调用write_meta_graph=False
时通过saver.save()
。但是,这使您无需稍后导入MetaGraph。
如果要保存MetaGraph 和以避免警告,则需要实现必要的挂钩(to_proto
和from_proto
)以序列化Model
object作为tf.train.MetaGraphDef
序列化格式的协议缓冲区。 MetaGraph tutorial解释了如何执行此操作,但基本思路如下:
定义描述ModelProto
对象内容的协议缓冲区(Model
)。
定义model_to_proto()
函数,将Model
序列化为ModelProto
:
def model_to_proto(model):
ret = ModelProto()
# Set fields of `ret` from `model`.
return ret
定义model_from_proto()
函数,对ModelProto
进行反序列化并返回Model
:
def model_from_proto(model_proto):
# Construct a `Model` from the fields of `model_proto`.
return Model(...)
注册"train_model"
集合的功能。这当前使用的是一个名为register_proto_function()
from tensorflow.python.framework import ops
ops.register_proto_function("train_model",
proto_type=ModelProto,
to_proto=model_to_proto,
from_proto=model_from_proto)