TensorFlow: What are the input nodes for a tf.Estimator model?

Time: 2018-01-19 07:00:40

Tags: python tensorflow deep-learning tensorflow-serving tensorflow-estimator

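I trained a wide & deep model using the pre-made Estimator class (DNNLinearCombinedClassifier), essentially following the tutorial on tensorflow.org.

I want to do inference/serving, but without using TensorFlow Serving. This basically comes down to feeding some test data to the correct input tensors and retrieving the output tensors.

However, I'm not sure what the input nodes/layer should be. In the TensorFlow graph (graph.pbtxt), the following nodes seem relevant. But they are also tied to the input queue, which is mostly used during training and not necessarily for inference (I may send only one instance at a time).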

  name: "enqueue_input/random_shuffle_queue"
  name: "enqueue_input/Placeholder"
  name: "enqueue_input/Placeholder_1"
  name: "enqueue_input/Placeholder_2"
  ...
  name: "enqueue_input/Placeholder_84"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_1"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_2"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_3"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_4"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany"
  name: "enqueue_input/sub/y"
  name: "enqueue_input/sub"
  name: "enqueue_input/Maximum/x"
  name: "enqueue_input/Maximum"
  name: "enqueue_input/Cast"
  name: "enqueue_input/mul/y"
  name: "enqueue_input/mul"
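
One way to narrow down candidate input nodes in graph.pbtxt is to list the nodes by op type. The following is only a rough sketch: it scans the text-format GraphDef with a regular expression instead of TensorFlow's own parser, the helper name list_nodes_by_op and the sample text are made up for illustration, and it assumes each node block lists name before op, as graph.pbtxt dumps normally do.

```python
import re

# Matches the start of each node block in a text-format GraphDef,
# capturing the node name and its op type.
NODE_RE = re.compile(r'node\s*\{\s*name:\s*"([^"]+)"\s*op:\s*"([^"]+)"')


def list_nodes_by_op(pbtxt_text, op_type="Placeholder"):
    """Return the names of all nodes whose op type equals op_type."""
    return [name for name, op in NODE_RE.findall(pbtxt_text) if op == op_type]


# Illustrative sample only; a real graph.pbtxt has many more fields per node.
sample = '''
node {
  name: "enqueue_input/Placeholder"
  op: "Placeholder"
}
node {
  name: "enqueue_input/random_shuffle_queue"
  op: "RandomShuffleQueueV2"
}
node {
  name: "enqueue_input/Placeholder_1"
  op: "Placeholder"
}
'''

print(list_nodes_by_op(sample))
```

On a real file, the text would come from reading graph.pbtxt in the model directory rather than from an inline string.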

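Does anyone know the answer? Thanks in advance!

1 Answer:

Answer 0: (score: 4)

If you want to do inference without TensorFlow Serving, you can simply use the predict method of tf.estimator.Estimator.

However, if you want to do it manually (which can run faster), you need a workaround. I'm not sure that what I did is the best approach, but it works. Here is my solution.

1) Let's do the imports and create the variables and dummy data: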

import os
import numpy as np
from functools import partial
import pickle
import tensorflow as tf

N = 10000
EPOCHS = 1000
BATCH_SIZE = 2

X_data = np.random.random((N, 10))
y_data = (np.random.random((N, 1)) >= 0.5).astype(int)

my_dir = os.getcwd() + "/"

2) Define an input_fn that uses tf.data.Dataset. Save the tensor names in a dictionary ("input_tensor_map") that maps each input key to its tensor name.

def my_input_fn(X, y=None, is_training=False):

    def internal_input_fn(X, y=None, is_training=False):

        if (not isinstance(X, dict)):
            X = {"x": X}

        if (y is None):
            dataset = tf.data.Dataset.from_tensor_slices(X)
        else:
            dataset = tf.data.Dataset.from_tensor_slices((X, y))

        if (is_training):
            dataset = dataset.repeat().shuffle(100)
            batch_size = BATCH_SIZE
        else:
            batch_size = 1

        dataset = dataset.batch(batch_size)

        dataset_iter = dataset.make_initializable_iterator()

        if (y is None):
            features = dataset_iter.get_next()
            labels = None
        else:
            features, labels = dataset_iter.get_next()

        input_tensor_map = dict()
        for input_name, tensor in features.items():
            input_tensor_map[input_name] = tensor.name

        with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'wb') as f:
            pickle.dump(input_tensor_map, f, protocol=pickle.HIGHEST_PROTOCOL)

        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, dataset_iter.initializer)

        return (features, labels) if labels is not None else features

    return partial(internal_input_fn, X=X, y=y, is_training=is_training)

3) Define the model to use with tf.estimator.Estimator. For example:

def my_model_fn(features, labels, mode):

    output = tf.layers.dense(inputs=features["x"], units=1, activation=None)
    logits = tf.identity(output, name="logits")
    prediction = tf.nn.sigmoid(logits, name="predictions")
    classes = tf.to_int64(tf.greater(logits, 0.0), name="classes")

    predictions_dict = {
                "class": classes,
                "probabilities": prediction
                }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions_dict)

    # The binary labels already match the shape of the single logit, so no one-hot encoding is needed
    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits)

    tf.summary.scalar("loss", loss)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(labels, classes)))
    tf.summary.scalar("accuracy", accuracy)

    # Configure the Training Op (for TRAIN mode)
    if (mode == tf.estimator.ModeKeys.TRAIN):
        train_op = tf.train.AdamOptimizer().minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)
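
To make the two prediction heads concrete, here is a small NumPy-only sketch of what "predictions" and "classes" compute from the logits. It is not part of the model code; the helper name predict_heads and the logit values are made up for illustration.

```python
import numpy as np


def predict_heads(logits):
    """Mirror the two prediction heads of my_model_fn using NumPy."""
    logits = np.asarray(logits, dtype=np.float64)
    probabilities = 1.0 / (1.0 + np.exp(-logits))  # same role as tf.nn.sigmoid
    classes = (logits > 0.0).astype(np.int64)      # same role as tf.to_int64(tf.greater(logits, 0.0))
    return probabilities, classes


logits = np.array([[-2.0], [0.0], [3.0]])  # made-up logit values
probabilities, classes = predict_heads(logits)
print(probabilities.ravel(), classes.ravel())
```

Thresholding the logits at 0 gives the same classes as thresholding the probabilities at 0.5, because sigmoid(z) > 0.5 exactly when z > 0.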

4) Train and freeze the model. The freezing method comes from "TensorFlow: How to freeze a model and serve it with a python API", with a small modification of mine.

def freeze_graph(output_node_names):
    """Extract the sub graph defined by the output nodes and convert
    all its variables into constants.
    The checkpoint is read from the module-level my_dir.
    Args:
        output_node_names: a string containing all the output node names,
                            comma separated
    """
    if (output_node_names is None):
        output_node_names = 'loss'

    if not tf.gfile.Exists(my_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % my_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve the full path of our checkpoint
    checkpoint = tf.train.get_checkpoint_state(my_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # We specify the full file name of our frozen graph
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + "/frozen_model.pb"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the useful nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

# *****************************************************************************

tf.logging.set_verbosity(tf.logging.INFO)

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=my_dir)

if (estimator.latest_checkpoint() is None):
    estimator.train(input_fn=my_input_fn(X=X_data, y=y_data, is_training=True), steps=EPOCHS)
    freeze_graph("predictions,classes")

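5) Finally, you can use the frozen graph for inference; the input tensor names are in the saved dictionary. Again, the method for loading the frozen model comes from "TensorFlow: How to freeze a model and serve it with a python API".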

def load_frozen_graph(prefix="frozen_graph"):
    frozen_graph_filename = os.path.join(my_dir, "frozen_model.pb")

    # We load the protobuf file from the disk and parse it to retrieve the 
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # The name argument prefixes every op/node in the imported graph;
        # the tensor lookups at inference time must use the same prefix
        tf.import_graph_def(graph_def, name=prefix)

    return graph

# *****************************************************************************

X_test = {"x": np.random.random((int(N/2), 10))}

prefix = "frozen_graph"
graph = load_frozen_graph(prefix)

for op in graph.get_operations():
    print(op.name)

with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'rb') as f:
    input_tensor_map = pickle.load(f)

with tf.Session(graph=graph) as sess:
    input_feed = dict()

    for key, tensor_name in input_tensor_map.items():
        tensor = graph.get_tensor_by_name(prefix + "/" + tensor_name)
        input_feed[tensor] = X_test[key]

    logits = graph.get_operation_by_name(prefix + "/logits").outputs[0]
    probabilities = graph.get_operation_by_name(prefix + "/predictions").outputs[0]
    classes = graph.get_operation_by_name(prefix + "/classes").outputs[0]

    logits_values, probabilities_values, classes_values = sess.run([logits, probabilities, classes], feed_dict=input_feed)
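
The lookups above rely on two naming conventions: a tensor is named "<op_name>:<output_index>", and tf.import_graph_def(..., name=prefix) puts "prefix/" in front of every op name. Here is a minimal sketch of that name handling in plain Python (no TensorFlow needed). The helper names are hypothetical, and the tensor name "IteratorGetNext:0" is only an assumed example; the real names are whatever was saved in input_tensor_map.pickle.

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, output_index).

    A bare op name (no ':') is treated as referring to output 0.
    """
    if ":" not in tensor_name:
        return tensor_name, 0
    op_name, index = tensor_name.rsplit(":", 1)
    return op_name, int(index)


def build_feed_names(input_tensor_map, prefix="frozen_graph"):
    """Prepend the import prefix to every saved input tensor name."""
    return {key: prefix + "/" + name for key, name in input_tensor_map.items()}


# Assumed example content of input_tensor_map.pickle.
input_tensor_map = {"x": "IteratorGetNext:0"}

feed_names = build_feed_names(input_tensor_map)
print(feed_names)
print(split_tensor_name(feed_names["x"]))
```

In other words, graph.get_operation_by_name(prefix + "/logits").outputs[0] and graph.get_tensor_by_name(prefix + "/logits:0") refer to the same tensor.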