使用Tensorflow模型保存类

时间:2019-02-08 08:58:55

标签: tensorflow

假设我有一个封装在类tensorflow模型中,该模型以类似的方式定义:

class Model:
    def __init__(self, ...):
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)

        <some flags, numbers, numpy arrays>
        <some tf variables and placeholders>
        <tf initialization>

保存和恢复具有所有属性的模型的最佳实践是什么?

1 个答案:

答案 0 :(得分:0)

我不确定是否有最佳做法。但是这里是使用simple-save时要提防的地方:

迭代器和数据集:

请注意数据集上的迭代器,因为可能无法从保存的图形中还原迭代器。由于恢复时它可能会初始化一个新的。

不要

# Save
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iterator = dataset.make_initializable_iterator()
<iterator get next>

# Restore
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iterator = dataset.make_initializable_iterator()

执行

# Save
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_init_op = iterator.make_initializer(dataset, name='dataset_initializer')

# Resotre
dataset_init_op = graph.get_operation_by_name('dataset_initializer')

随时在下面添加点。