Exporting trained TensorFlow model parameters to the SavedModel format

Posted: 2017-07-27 19:59:54

Tags: tensorflow google-cloud-ml-engine

I have built a system that uses Google ML Engine to train various text classifiers with a simple flat CNN architecture (borrowed from the excellent WildML post). I also lean heavily on the ML Engine trainer template that exists here - specifically the variant that uses core TensorFlow functions.

My problem is that, although the model trains correctly and learns its parameters, I cannot get the serialized export in the binary SavedModel format (i.e. a .pb file) to retain the learned weights. I can tell this is the case by running gcloud predict local against the model's export folder: the predictions come back random every time, which leads me to believe that while the graph structure is saved in proto-buf format, the associated weights from the checkpoint file are not being carried over.
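(For reference, a quick way to confirm that the checkpoint itself does hold trained values - i.e. that the problem is in the export step rather than in training - is to read the variables straight out of the checkpoint file. A minimal sketch, with a hypothetical checkpoint path and the variable names used by the TextCNN below:)

import tensorflow as tf

# Hypothetical checkpoint path - adjust to your job_dir and global step
checkpoint_path = './my_job_dir/model/model.ckpt-2000'

# List every variable stored in the checkpoint along with its shape
reader = tf.train.NewCheckpointReader(checkpoint_path)
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)

# Read one trained tensor, e.g. the embedding matrix defined in the TextCNN
print(reader.get_tensor('embedding/W'))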

Here is the code for my run function:

def run(...):

  # ... code to load and transform train/test data

  with train_graph.as_default():
    with tf.Session(graph=train_graph).as_default() as session:

      # Features and label tensors as read using filename queue
      features, labels = model.input_fn(
          x_train,
          y_train,
          num_epochs=num_epochs,
          batch_size=train_batch_size
      )

      # Returns the training graph and global step tensor
      tf.logging.info("Train vocab size: {:d}".format(vocab_size))
      train_op, global_step_tensor, cnn, train_summaries = model.model_fn(
          model.TRAIN,
          sequence_length,
          num_classes,
          label_values,
          vocab_size,
          embedding_size,
          filter_sizes,
          num_filters
      )
      tf.logging.info("Created simple training CNN with ({}) filter types".format(filter_sizes))

      # Setup writers
      train_summary_op = tf.summary.merge(train_summaries)
      train_summary_dir = os.path.join(job_dir, "summaries", "train")

      # Generate writer
      train_summary_writer = tf.summary.FileWriter(train_summary_dir, session.graph)

      # Initialize all variables
      session.run(tf.global_variables_initializer())
      session.run(tf.local_variables_initializer())

      model_dir = os.path.abspath(os.path.join(job_dir, "model"))
      if not os.path.exists(model_dir):
          os.makedirs(model_dir)
      saver = tf.train.Saver()

      def train_step(x_batch, y_batch):
          """
          A single training step
          """
          feed_dict = {
            cnn.input_x: x_batch,
            cnn.input_y: y_batch,
            cnn.dropout_keep_prob: 0.5
          }

          step, _, loss, accuracy = session.run([global_step_tensor, train_op, cnn.loss, cnn.accuracy],
                                                feed_dict=feed_dict)
          time_str = datetime.datetime.now().isoformat()
          if step % 10 == 0:
              tf.logging.info("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))

          # Return current step
          return step

      def eval_step(x_batch, y_batch, train_step, total_steps):
            """
            Evaluates model on a dev set
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: 1.0
            }
            step, loss, accuracy, scores, predictions = session.run([global_step_tensor, cnn.loss, cnn.accuracy, cnn.scores, cnn.predictions],
                                                feed_dict=feed_dict)

            # Get metrics
            y_actual = np.argmax(y_batch, 1)
            model_metrics = precision_recall_fscore_support(y_actual, predictions)

            #print(scores)
            time_str = datetime.datetime.now().isoformat()
            print("\n---- EVAULATION ----")
            avg_precision = np.mean(model_metrics[0], axis=0)
            avg_recall = np.mean(model_metrics[1], axis=0)
            avg_f1 = np.mean(model_metrics[2], axis=0)
            print("{}: step {}, loss {:g}, acc {:g}, prec {:g}, rec {:g}, f1 {:g}".format(time_str, step, loss, accuracy, avg_precision, avg_recall, avg_f1))
            print("Model metrics: ", model_metrics)
            print("---- EVALUATION ----\n")


      # Generate batches
      batches = data_helpers.batch_iter(
          list(zip(features, labels)), train_batch_size, num_epochs)

      # Training loop. For each batch...
      for batch in batches:
          x_batch, y_batch = zip(*batch)
          current_step = train_step(x_batch, y_batch)

          if current_step % 20 == 0 or current_step == 1:
              eval_step(x_eval, y_eval, current_step, total_steps)

      # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
      print(model_dir)
      trained_model = saver.save(session, os.path.join(job_dir, 'model') + "/model.ckpt", global_step=current_step)
      print(trained_model)
      print("Saved final model checkpoint to {}".format(trained_model))

      # Only perform this if chief
      if is_chief:
          build_and_run_exports(trained_model, job_dir,
                              model.SERVING_INPUT_FUNCTIONS[model.TEXT],
                              sequence_length, num_classes, label_values,
                              vocab_size, embedding_size, filter_sizes,
                              num_filters, vocab_processor)

My build_and_run_exports function:

def build_and_run_exports(...):
    # Check if the export already exists - if so, delete it
    export_dir = os.path.join(job_dir, 'export')
    if os.path.exists(export_dir):
      print("Export currently exists - going to delete:", export_dir)
      shutil.rmtree(export_dir)

    # Create exporter
    exporter = tf.saved_model.builder.SavedModelBuilder(export_dir)

    # Restore prediction graph
    prediction_graph = tf.Graph()
    with prediction_graph.as_default():

      with tf.Session(graph=prediction_graph) as session:
            # Build the serving inputs
            features, inputs_dict = serving_input_fn()

            # Setup inputs
            inputs_info = {
                name: tf.saved_model.utils.build_tensor_info(tensor)
                for name, tensor in inputs_dict.iteritems()
            }

            # Load model
            cnn = TextCNN(
                sequence_length=sequence_length,
                num_classes=num_classes,
                vocab_size=vocab_size,
                embedding_size=embedding_size,
                filter_sizes=list(map(int, filter_sizes.split(","))),
                num_filters=num_filters,
                input_tensor=features)

            # Restore model
            saver = tf.train.Saver()
            saver.restore(session, latest_checkpoint)

            # Setup outputs
            outputs = {
                'logits': cnn.scores,
                'probabilities': cnn.probabilities,
                'predicted_indices': cnn.predictions
            }

            # Create output info
            output_info = {
                name: tf.saved_model.utils.build_tensor_info(tensor)
                for name, tensor in outputs.iteritems()
            }

            # Setup signature definition
            signature_def = tf.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs_info,
                outputs=output_info,
                method_name=sig_constants.PREDICT_METHOD_NAME
            )

            # Create graph export
            exporter.add_meta_graph_and_variables(
                session,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
                },
                legacy_init_op=tf.saved_model.main_op.main_op()
            )

            # Export model
            exporter.save()

Last but not least, the TextCNN model:

class TextCNN(object):
    """
    A CNN for text classification.
    Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.
    """
    def __init__(
      self, sequence_length, num_classes, vocab_size,
      embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0,
      dropout_keep_prob=0.5, input_tensor=None):

        # Setup input
        if input_tensor is not None:
            self.input_x = input_tensor
            self.dropout_keep_prob = tf.constant(1.0)
        else:
            self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
            self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # Placeholders for input, output and dropout
        self.input_y = tf.placeholder(tf.int32, [None, num_classes], name="input_y")

        # Keeping track of l2 regularization loss (optional)
        l2_loss = tf.constant(0.0)

        # Embedding layer
        with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                conv = tf.nn.conv2d(
                    self.embedded_chars_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Maxpooling over the outputs
                pooled = tf.nn.max_pool(
                    h,
                    ksize=[1, sequence_length - filter_size + 1, 1, 1],
                    strides=[1, 1, 1, 1],
                    padding='VALID',
                    name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable(
                "W",
                shape=[num_filters_total, num_classes],
                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        with tf.name_scope("probabilities"):
            self.probabilities = tf.nn.softmax(logits=self.scores)

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")

I hope I'm just missing something simple about how I'm creating the TF graph/session and restoring the saved state.

Thanks in advance for your help!

1 Answer:

Answer 0 (score: 2)

This behavior is caused by tf.saved_model.main_op.main_op(), which randomly (re)initializes all of the variables in the graph (code). However, the legacy_init_op runs after the variables have been restored from the checkpoint (the restore happens here, followed by the legacy_init_op here), so the restored weights get wiped out.

The solution is simply not to re-initialize all of the variables; for example, in your code:

from tensorflow.python.ops import variables
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import control_flow_ops

def my_main_op():
  init_local = variables.local_variables_initializer()
  init_tables = lookup_ops.tables_initializer()
  return control_flow_ops.group(init_local, init_tables)

def build_and_run_exports(...):
  ...
            # Create graph export
            exporter.add_meta_graph_and_variables(
                session,
                tags=[tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    sig_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_def
                },
                legacy_init_op=my_main_op()
            )

            # Export model
            exporter.save()
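As a follow-up, here is a minimal sketch of how the re-exported SavedModel could be checked against the checkpoint, assuming the export directory layout from build_and_run_exports above and the 'embedding/W' variable name from the TextCNN (the paths are hypothetical):

import numpy as np
import tensorflow as tf

# Hypothetical paths - adjust to your job_dir layout
export_dir = './my_job_dir/export'
checkpoint_path = './my_job_dir/model/model.ckpt-2000'

# Read the trained embedding matrix straight from the checkpoint
ckpt_embedding = tf.train.NewCheckpointReader(checkpoint_path).get_tensor('embedding/W')

# Load the SavedModel and evaluate the same variable from the restored graph
with tf.Session(graph=tf.Graph()) as session:
    tf.saved_model.loader.load(session, [tf.saved_model.tag_constants.SERVING], export_dir)
    exported_embedding = session.run('embedding/W:0')

# If the export preserved the learned weights, the two arrays should match
print(np.allclose(ckpt_embedding, exported_embedding))

With the fix above in place this should print True; with the original tf.saved_model.main_op.main_op() it prints False, because the variables are re-initialized after the restore.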