将cifar10模型从检查点文件导出到tensorflow服务

时间:2016-06-24 18:07:09

标签: tensorflow tensorflow-serving

我尝试修改CIFAR10模型的inception_export.py,但是我收到了错误:

raise type(e)(node_def, op, message) tensorflow.python.framework.errors.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [18,384] rhs shape= [2304,384]
     [[Node: save/Assign_5 = Assign[T=DT_FLOAT, _class=["loc:@local3/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](local3/weights, save/restore_slice_5)]] Caused by op u'save/Assign_5', defined at:  

我对tensorflow仍然很新,任何帮助都非常感谢,谢谢

EDIT1:这是我的代码。我还没有安装tensorflow服务,因此相关的块被注释掉了。我还将image_size更改为24以适合CIFAR10模型。

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#!/usr/bin/env python2.7
"""Modified for CIFAR10 model from https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_export.py
"""

import os.path
import sys

# This is a placeholder for a Google-internal import.

import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10
#from inception import inception_model

#from tensorflow_serving.session_bundle import exporter


tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
                           """Directory where to read training checkpoints.""")
tf.app.flags.DEFINE_string('export_dir', '/tmp/cifar10_export',
                           """Directory where to export inference model.""")
tf.app.flags.DEFINE_integer('image_size', 24,
                            """Needs to provide same value as in training.""")
FLAGS = tf.app.flags.FLAGS


NUM_CLASSES = 10
NUM_TOP_CLASSES = 2

WORKING_DIR = os.path.dirname(os.path.realpath(__file__))
SYNSET_FILE = os.path.join(WORKING_DIR, 'imagenet_lsvrc_2015_synsets.txt')
METADATA_FILE = os.path.join(WORKING_DIR, 'imagenet_metadata.txt')


def export():
  """can be deleted if my simply define the constant string manually below?
  # Create index->synset mapping
  synsets = []
  with open(SYNSET_FILE) as f:
    synsets = f.read().splitlines()
  # Create synset->metadata mapping
  texts = {}
  with open(METADATA_FILE) as f:
    for line in f.read().splitlines():
      parts = line.split('\t')
      assert len(parts) == 2
      texts[parts[0]] = parts[1]
  """
  with tf.Graph().as_default():
    # Build inference model.
    # Please refer to Tensorflow inception model for details.

    # Input transformation.
    # TODO(b/27776734): Add batching support.
    jpegs = tf.placeholder(tf.string, shape=(1))
    image_buffer = tf.squeeze(jpegs, [0])
    # Decode the string as an RGB JPEG.
    # Note that the resulting image contains an unknown height and width
    # that is set dynamically by decode_jpeg. In other words, the height
    # and width of image is unknown at compile-time.
    image = tf.image.decode_jpeg(image_buffer, channels=3)
    # After this point, all image pixels reside in [0,1)
    # until the very end, when they're rescaled to (-1, 1).  The various
    # adjust_* ops all require this range for dtype float.
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)
    # Crop the central region of the image with an area containing 87.5% of
    # the original image.
    image = tf.image.central_crop(image, central_fraction=0.875)
    # Resize the image to the original height and width.
    image = tf.expand_dims(image, 0)
    image = tf.image.resize_bilinear(image,
                                     [FLAGS.image_size, FLAGS.image_size],
                                     align_corners=False)
    image = tf.squeeze(image, [0])
    # Finally, rescale to [-1,1] instead of [0, 1)
    image = tf.sub(image, 0.5)
    image = tf.mul(image, 2.0)
    images = tf.expand_dims(image, 0)

    # Run inference.
    logits = cifar10.inference(images)

    # Transform output to topK result.
    values, indices = tf.nn.top_k(logits, NUM_TOP_CLASSES)

    # Create a constant string Tensor where the i'th element is
    # the human readable class description for the i'th index.
    # Note that the 0th index is an unused background class
    # (see inception model definition code).
    class_descriptions =  ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # for s in synsets:
      # class_descriptions.append(texts[s])
    class_tensor = tf.constant(class_descriptions)

    classes = tf.contrib.lookup.index_to_string(tf.to_int64(indices),
                                                mapping=class_tensor)

    # Restore variables from training checkpoint.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)
    with tf.Session() as sess:
      # Restore variables from training checkpoints.
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        # Assuming model_checkpoint_path looks something like:
        #   /my-favorite-path/imagenet_train/model.ckpt-0,
        # extract global_step from it.
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('Successfully loaded model from %s at step=%s.' %
              (ckpt.model_checkpoint_path, global_step))
      else:
        print('No checkpoint file found at %s' % FLAGS.checkpoint_dir)
        return
      """ Not exporting yet because I haven't installed tensorflow serving
      # Export inference model.
      init_op = tf.group(tf.initialize_all_tables(), name='init_op')
      model_exporter = exporter.Exporter(saver)
      signature = exporter.classification_signature(
          input_tensor=jpegs, classes_tensor=classes, scores_tensor=values)
      model_exporter.init(default_graph_signature=signature, init_op=init_op)
      model_exporter.export(FLAGS.export_dir, tf.constant(global_step), sess)
      print('Successfully exported model to %s' % FLAGS.export_dir)
      """

def main(unused_argv=None):
  export()


if __name__ == '__main__':
  tf.app.run()

0 个答案:

没有答案