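How do I make TensorForestEstimator save the model in SaverDef.V2 format?
When I train a TensorForestEstimator, let it save the graph and checkpoint, and then freeze the result with freeze_graph, I get the following error: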
Unable to open table file test/checkpoint: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
[[Node: save/RestoreV2_33 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/RestoreV2_33/tensor_names, save/RestoreV2_33/shape_and_slices)]]
It looks like TensorForestEstimator saves the model in the V1 format, because I get the following warning when the model is saved:
WARNING:tensorflow:*******************************************************
WARNING:tensorflow:TensorFlow's V1 checkpoint format has been deprecated.
WARNING:tensorflow:Consider switching to the more efficient V2 format:
WARNING:tensorflow: `tf.train.Saver(write_version=tf.train.SaverDef.V2)`
WARNING:tensorflow:now on by default.
WARNING:tensorflow:*******************************************************
I cannot find any TensorForestEstimator parameter that requests the V2 format. This is how I construct and train the estimator:
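import numpy as np
import tensorflow as tf
# TensorForestEstimator is assumed to be imported already;
# its module path depends on the TensorFlow version.

hparams = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
    num_trees=3, max_nodes=1000, num_classes=3, num_features=4).fill()
classifier = TensorForestEstimator(hparams, model_dir='test/')

# Train on the iris dataset; checkpoints are written to model_dir.
iris = tf.contrib.learn.datasets.load_iris()
data = iris.data.astype(np.float32)
target = iris.target.astype(np.float32)
classifier.fit(x=data, y=target, steps=100)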
This is the code that freezes the graph:
import os

from tensorflow.python.tools import freeze_graph
from tensorflow.contrib.tensor_forest.python.ops import training_ops
from tensorflow.contrib.tensor_forest.python.ops import inference_ops

# Load the custom tensor_forest ops so the graph can be deserialized.
training_ops.Load()
inference_ops.Load()

model_path = 'test/'
checkpoint_state_name = "checkpoint"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"

checkpoint_path = os.path.join(model_path, checkpoint_state_name)
input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = ""
input_binary = False  # graph.pbtxt is a text-format GraphDef
output_node_names = "output_node"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = False

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path,
                          output_node_names, restore_op_name,
                          filename_tensor_name, output_graph_path,
                          clear_devices, "")
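One thing that may help narrow this down: the error names test/checkpoint, which is normally the small text state file that records the latest checkpoint prefix, not a checkpoint itself. Below is a minimal, stdlib-only sketch for checking what is actually on disk. `describe_checkpoints` is a hypothetical helper, not part of TensorFlow. It assumes the usual layout, where the state file contains a `model_checkpoint_path: "..."` line and a V2 checkpoint consists of `<prefix>.index` plus one or more `<prefix>.data-*` shards.

```python
import os
import re


def describe_checkpoints(model_dir):
    """Summarize the checkpoint files found in model_dir (hypothetical helper)."""
    result = {"state_file_prefix": None, "v2_prefixes": []}

    # The `checkpoint` state file is plain text; the latest checkpoint prefix
    # is recorded on a line such as: model_checkpoint_path: "model.ckpt-100"
    state_path = os.path.join(model_dir, "checkpoint")
    if os.path.isfile(state_path):
        with open(state_path) as f:
            match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"',
                              f.read(), re.MULTILINE)
        if match:
            result["state_file_prefix"] = match.group(1)

    # Assumption: a V2 checkpoint is <prefix>.index plus <prefix>.data-* shards.
    names = set(os.listdir(model_dir))
    for name in sorted(names):
        if name.endswith(".index"):
            prefix = name[:-len(".index")]
            if any(n.startswith(prefix + ".data-") for n in names):
                result["v2_prefixes"].append(prefix)
    return result
```

Calling `describe_checkpoints('test/')` after training would show whether V2 files are present. If `v2_prefixes` is non-empty, the checkpoint may already be in V2 format, and the value to try for checkpoint_path would be `os.path.join('test/', prefix)` rather than test/checkpoint.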