tf.cond无法正确序列化?

时间:2017-03-22 11:09:47

标签: python serialization tensorflow

tf.cond可以导出使用SerializeToString的图表吗? 我正在对CNN模型进行一些预处理,这需要输入大小至少为224,并且我想确保输入图像在小于该大小时调整为224,如果输入图像小于该大小,则不执行任何操作。图像大于那个。

以下代码可以按预期工作。sess.run(...)结果显示图像已正确调整大小。然后我将整个图形导出(序列化)到.pb文件中。

import tensorflow as tf
from tensorflow.python.framework import graph_util

input_jpeg = tf.placeholder(tf.string, name='DecodeJpeg/contents')
image = tf.image.decode_jpeg(input_jpeg, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
height = tf.constant(224)
width = tf.constant(224)
def resize_t(image_t, h=height, w=width):
  image2 = tf.expand_dims(image_t, axis=0)
  image2= tf.image.resize_bilinear(image2, [h, w], align_corners=False)
  image2 = tf.squeeze(image2, [0])
  return image2
input_h = tf.shape(image)[0]
input_w = tf.shape(image)[1]
# seems logical_or can not work here...why?
#condition = tf.logical_or(tf.less(input_h, height), tf.less(input_w, width))
#image_checked = tf.cond( condition, lambda:resize_t(image), lambda:image)

image_checked = tf.cond(
    (tf.less(input_h, height)),
    lambda:resize_t(image), lambda:image)
image_checked = tf.cond(
    (tf.less(input_h, height)),
    lambda:resize_t(image), lambda:image)
shape_t = tf.shape(image_checked, name='shape_t')

with tf.Session() as sess:
  # a pic which is smaller than 224x224
  with open('1.jpg', 'rb') as f:
    content = f.read()
  shape_value = sess.run(shape_t, feed_dict={input_jpeg:content})
  # with an output:
  # [224 224   3]
  # it works fine until now ...
  print(shape_value)

  # write the graph to pbtxt
  tf.train.write_graph(sess.graph_def,
     './', 'proto.pbtxt', True)

  # freeze the graph and export
  output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['shape_t'])
  output_graph_def = graph_util.remove_training_nodes(output_graph_def)
  with open('out.pb', 'wb') as f:
    f.write(output_graph_def.SerializeToString())

在成功生成out.pb文件后,我尝试使用以下代码导入文件:

import tensorflow as tf
with open('out.pb', 'rb') as f:
  graph_content = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(graph_content)
_ = tf.import_graph_def(graph_def, name='')

输出操作失败:

Traceback (most recent call last):
  File "im.py", line 6, in <module>
    _ = tf.import_graph_def(graph_def, name='')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 342, in import_graph_def
    % (input_name,)))
ValueError: graph_def is invalid at node u'cond_1/ExpandDims/dim': More inputs specified ('cond_1/Switch:1') than the op expects..

似乎操作expand_dims收到了意外参数cond_1/Switch:1.pbtxt生成的tf.train.write_graph文件显示 添加到expand_dims/dim的其他输入:

node {
  name: "cond_1/ExpandDims/dim"
  op: "Const"
  input: "^cond_1/switch_t"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}

应该是(当无条件执行expand_dim / resize时):

node {
  name: "ExpandDims/dim"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}

我误解了什么吗?如何将tf.cond与导出的.pb文件一起使用?

0 个答案:

没有答案