TensorFlow TensorForestEstimator: saving the model as SaverDef.V2

Date: 2016-12-13 18:07:31

Tags: tensorflow

How can I make TensorForestEstimator save its model in the SaverDef.V2 format?

When I save the graph with TensorForestEstimator and then freeze it with freeze_graph, I get the following error:

Unable to open table file test/checkpoint: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
     [[Node: save/RestoreV2_33 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_33/tensor_names, save/RestoreV2_33/shape_and_slices)]]
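The message appears to come from TensorFlow's checkpoint table reader: it tried to open test/checkpoint as an SSTable and rejected it because the file does not end with the table magic number. The sketch below performs a similar footer check in plain Python, assuming that TensorFlow's table files keep LevelDB's footer layout (the last 8 bytes are the little-endian magic number 0xdb4775248b80fb57). The function name looks_like_sstable is hypothetical, not a TensorFlow API.

```python
import struct

# Assumption: TensorFlow's table format keeps LevelDB's footer, whose last
# 8 bytes hold this magic number encoded as a little-endian 64-bit integer.
TABLE_MAGIC_NUMBER = 0xdb4775248b80fb57


def looks_like_sstable(path):
    """Return True if the file at `path` ends with the table magic number."""
    with open(path, "rb") as f:
        f.seek(0, 2)  # whence=2: move to the end of the file to get its size
        size = f.tell()
        if size < 8:
            return False  # too short to contain a footer at all
        f.seek(size - 8)
        (magic,) = struct.unpack("<Q", f.read(8))
    return magic == TABLE_MAGIC_NUMBER
```

Applied to the state file the estimator writes (test/checkpoint), which is plain text, this check would be expected to return False, which matches the "bad magic number" complaint.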

It seems that TensorForestEstimator saves the model in the V1 format, because I get the following warning when the model is saved:

WARNING:tensorflow:*******************************************************
WARNING:tensorflow:TensorFlow's V1 checkpoint format has been deprecated.
WARNING:tensorflow:Consider switching to the more efficient V2 format:
WARNING:tensorflow:   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`
WARNING:tensorflow:now on by default.
WARNING:tensorflow:*******************************************************
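Rather than relying on the warning alone, one way to check which format actually landed in model_dir is to look at the file names. The sketch below assumes the usual layouts: V2 writes `<prefix>.index` plus `<prefix>.data-*` shards, while V1 writes `<prefix>` itself or sharded `<prefix>-?????-of-?????` files. The helper checkpoint_format is hypothetical.

```python
import glob
import os


def checkpoint_format(prefix):
    """Guess the on-disk format of the checkpoint written under `prefix`.

    Assumed layouts: V2 writes `<prefix>.index` plus `<prefix>.data-*` shards;
    V1 writes `<prefix>` itself or sharded `<prefix>-?????-of-?????` files.
    Returns "V2", "V1", or None if no matching file is found.
    """
    if os.path.exists(prefix + ".index"):
        return "V2"
    if os.path.isfile(prefix) or glob.glob(prefix + "-?????-of-?????"):
        return "V1"
    return None
```

For example, checkpoint_format("test/model.ckpt-100") would report which layout was used for that (hypothetical) step.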

TensorForestEstimator has no parameter for requesting the V2 format. Here is the code that constructs and trains the estimator:

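import numpy as np
import tensorflow as tf
# Assumed TF 0.12-era import path; later releases moved this class.
from tensorflow.contrib.learn.python.learn.estimators.random_forest import TensorForestEstimator

# 3 trees of at most 1000 nodes; Iris has 3 classes and 4 features.
hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
        num_trees=3, max_nodes=1000, num_classes=3, num_features=4).fill()
classifier = TensorForestEstimator(hparams, model_dir='test/')
iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)
classifier.fit(x=data, y=target, steps=100)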

Here is the code that freezes the graph:

import os

from tensorflow.python.tools import freeze_graph
from tensorflow.contrib.tensor_forest.python.ops import training_ops
from tensorflow.contrib.tensor_forest.python.ops import inference_ops

# Load the tensor_forest custom op libraries so the saved graph can be imported.
training_ops.Load()
inference_ops.Load()
model_path = 'test/'  # same directory passed as model_dir to the estimator

checkpoint_state_name = "checkpoint"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
checkpoint_path = os.path.join(model_path, checkpoint_state_name)

input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = ""
input_binary = False  # graph.pbtxt is a text-format GraphDef
output_node_names = "output_node"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = False

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")  # "" = no initializer nodes
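Note that checkpoint_path above resolves to test/checkpoint, which is the checkpoint state file: a small text file that lists checkpoint prefixes, not the checkpoint data itself. freeze_graph's input_checkpoint argument expects a checkpoint prefix such as test/model.ckpt-100, so one plausible cause of the error is the path rather than the V1/V2 format. The sketch below reads the newest prefix out of the state file using only the standard library; the helper name latest_checkpoint_prefix is hypothetical, and it ignores escaped characters inside the quoted path for simplicity.

```python
import os
import re

# Matches state-file lines such as:  model_checkpoint_path: "model.ckpt-100"
_STATE_LINE = re.compile(r'^model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_prefix(model_dir, state_name="checkpoint"):
    """Return the newest checkpoint prefix recorded in the state file, or None.

    Relative paths in the state file are resolved against model_dir.
    """
    state_path = os.path.join(model_dir, state_name)
    if not os.path.isfile(state_path):
        return None
    with open(state_path) as f:
        for line in f:
            match = _STATE_LINE.match(line.strip())
            if match:
                path = match.group(1)
                if not os.path.isabs(path):
                    path = os.path.join(model_dir, path)
                return path
    return None
```

Inside TensorFlow itself, tf.train.latest_checkpoint('test/') performs this lookup; its result, rather than test/checkpoint, would be the value to pass as checkpoint_path.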

0 answers
