保存和恢复张量流tf.contrib模型

时间:2016-11-22 22:50:18

标签: python tensorflow

有很多关于如何使用tf.train.Saver对象保存张量流模型的示例,但我想知道如何保存来自tf.contrib的模型。例如,使用https://www.tensorflow.org/versions/master/tutorials/tflearn/index.html上的快速入门演示中的代码:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

# Fit model.
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

现在我已对classifier进行了培训,如何将模型保存到磁盘,以便稍后我可以编写一个脚本:

# classifier = <some reloading code that reinitializes the trained model>

accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]

以最小的开销和对前一个脚本的了解?

0 个答案:

没有答案
相关问题