TensorFlow 保存的模型对象为空

时间:2018-11-17 01:47:26

标签: python tensorflow save

我正在训练一个随机森林分类器，并希望将模型保存为 protobuf（.pb）文件，以供之后使用。但实际导出时，得到的模型文件只有 3 KB（显然没有包含任何内容）。相关代码如下：

# Train a TensorForest random-forest classifier, evaluate and predict with it,
# then try to export a SavedModel to `export_dir`.
# NOTE(review): as written, the exported SavedModel does not contain the
# trained forest; see the notes at the export step at the bottom of this block.
#
# Forest hyperparameters. `.fill()` is called to finish building the params
# object (presumably it derives the remaining defaults; verify against the
# tensor_forest documentation).
params = tensor_forest.ForestHParams\
    (num_features=len(FEATURE_COLUMNS),\
    num_classes=num_classes,\
    regression=False,\
    num_trees=num_trees,\
    min_split_samples=min_split_samples,\
    max_nodes=max_nodes).fill()

# NOTE(review): `sess` is never passed to the estimator. The estimator appears
# to build and run its own graph/session internally (standard Estimator
# behavior; confirm). `sess` is also closed as soon as this `with` block exits.
with tf.Session() as sess:
    # The trained model is checkpointed under `model_dir` (every 60 s per
    # RunConfig), presumably the only place the learned trees end up.
    VK = random_forest.TensorForestEstimator(
                        params,model_dir=model_dir, config=tf.contrib.learn.RunConfig(save_checkpoints_secs=60))
    VK.fit(input_fn=train_input_fn, steps=steps)

# These run outside the `with` block and do not use `sess`.
results = VK.evaluate(input_fn=eval_input_fn, steps=1) #metrics=validation_metrics)
ref_res_predict = list(VK.predict(input_fn=eval_input_fn))

# For each prediction dict, take the probability at index d['classes'],
# i.e. the score of the predicted class (assumes 'classes' is an integer
# index into 'probabilities'; confirm).
values = [d['probabilities'][d['classes']] for d in ref_res_predict]

export_dir = 'rf_saved_model_2'
# NOTE(review): the message says "mode"; presumably "model" was intended.
print("Exporting trained mode to ", export_dir)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

# Serving inputs are created in the current default graph, which appears to
# hold none of the estimator's variables (see the note on the Session above).
# NOTE(review): feature width 101 is hard-coded; presumably it should equal
# len(FEATURE_COLUMNS). Verify.
serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
feature_configs = {'x': tf.FixedLenFeature(shape=[101], dtype=tf.float32),}
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
x = tf.identity(tf_example['x'], name='x')
# NOTE(review): `y` is a free placeholder with a hard-coded width of 3. Nothing
# connects it to `x` or to the forest, so a serving client would have to feed
# it. Presumably 3 should be `num_classes`. Verify.
y = tf.placeholder('float', shape=[None, 3])

classification_inputs = tf.saved_model.utils.build_tensor_info(x)
classification_outputs_classes = tf.saved_model.utils.build_tensor_info(y)
# `values` is the Python list computed above from evaluation predictions.
# Converting it embeds those fixed numbers as a constant, so this "scores"
# output never depends on the input `x`.
classification_outputs_scores = tf.saved_model.utils.build_tensor_info(tf.convert_to_tensor(values))

# Classification signature: input `x`; outputs are the disconnected `y`
# placeholder (classes) and the constant `values` tensor (scores).
classification_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={
              tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                  classification_inputs
          },
          outputs={
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                  classification_outputs_classes,
              tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                  classification_outputs_scores
          },
          method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME))

# Predict signature: reuses the same unconnected `x` -> `y` pair.
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'csv': tensor_info_x},
        outputs={'scores': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

# Groups the table initializers of the current graph (this block creates no
# lookup tables itself) into the op passed as `legacy_init_op` below.
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# NOTE(review): this is why the export is (nearly) empty.
#   * `sess` was closed when the `with` block above exited.
#   * Its graph appears to contain none of the forest's variables.
# The builder therefore has only the placeholder/constant ops defined above to
# write, which matches the ~3 KB output described in the question. A likely
# fix is to export through the estimator so that its trained checkpoint in
# `model_dir` is used, e.g. `VK.export_savedmodel(export_dir_base,
# serving_input_fn)`. Confirm the exact API of the installed tf.contrib.learn
# version.
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
        'predict':
        prediction_signature,
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        classification_signature,
},
legacy_init_op=legacy_init_op)

# Writes the SavedModel (saved_model.pb plus a variables/ directory) to
# `export_dir`.
builder.save()

0 个答案:

没有答案