Can a graph containing tf.cond be exported with SerializeToString?
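I'm doing some preprocessing for a CNN model that requires an input size of at least 224. I want to make sure the input image is resized to 224 when it is smaller than that, and left untouched when it is larger.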
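The intended rule can be stated without TensorFlow. As a plain-Python sketch (the helper name `checked_shape` is hypothetical, and it assumes the resize always targets exactly 224 x 224, as `resize_t` in the code below does), the expected output shape is:

```python
def checked_shape(height, width, channels=3, min_size=224):
    """Return the (H, W, C) shape after the 'resize only if too small' rule.

    Hypothetical helper mirroring
    tf.cond(logical_or(h < min_size, w < min_size), resize, identity),
    where the resize branch always produces a min_size x min_size image.
    """
    if height < min_size or width < min_size:
        # Too small in at least one dimension: resize to min_size x min_size.
        return (min_size, min_size, channels)
    # Already large enough: the image passes through unchanged.
    return (height, width, channels)


print(checked_shape(100, 150))  # (224, 224, 3)
print(checked_shape(480, 640))  # (480, 640, 3)
```

Because the target is a fixed 224 x 224, an image with only one short side is also rescaled along its other side; this mirrors the `logical_or` variant in the question rather than an aspect-preserving resize.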
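The following code works as expected: the result of sess.run(...) shows that the image is resized correctly. I then export (serialize) the whole graph to a .pb file.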
import tensorflow as tf
from tensorflow.python.framework import graph_util

input_jpeg = tf.placeholder(tf.string, name='DecodeJpeg/contents')
image = tf.image.decode_jpeg(input_jpeg, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)

height = tf.constant(224)
width = tf.constant(224)

def resize_t(image_t, h=height, w=width):
    image2 = tf.expand_dims(image_t, axis=0)
    image2 = tf.image.resize_bilinear(image2, [h, w], align_corners=False)
    image2 = tf.squeeze(image2, [0])
    return image2

input_h = tf.shape(image)[0]
input_w = tf.shape(image)[1]

# seems logical_or can not work here...why?
#condition = tf.logical_or(tf.less(input_h, height), tf.less(input_w, width))
#image_checked = tf.cond( condition, lambda:resize_t(image), lambda:image)
image_checked = tf.cond(
    (tf.less(input_h, height)),
    lambda:resize_t(image), lambda:image)
image_checked = tf.cond(
    (tf.less(input_h, height)),
    lambda:resize_t(image), lambda:image)
shape_t = tf.shape(image_checked, name='shape_t')

with tf.Session() as sess:
    # a pic which is smaller than 224x224
    with open('1.jpg', 'rb') as f:
        content = f.read()
    shape_value = sess.run(shape_t, feed_dict={input_jpeg:content})
    # with an output:
    # [224 224 3]
    # it works fine until now ...
    print(shape_value)
    # write the graph to pbtxt
    tf.train.write_graph(sess.graph_def,
                         './', 'proto.pbtxt', True)
    # freeze the graph and export
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['shape_t'])
    output_graph_def = graph_util.remove_training_nodes(output_graph_def)
    with open('out.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())
After out.pb was generated successfully, I tried to import it with the following code:
import tensorflow as tf
with open('out.pb', 'rb') as f:
    graph_content = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(graph_content)
_ = tf.import_graph_def(graph_def, name='')
The import fails with:
Traceback (most recent call last):
  File "im.py", line 6, in <module>
    _ = tf.import_graph_def(graph_def, name='')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py", line 342, in import_graph_def
    % (input_name,)))
ValueError: graph_def is invalid at node u'cond_1/ExpandDims/dim': More inputs specified ('cond_1/Switch:1') than the op expects..
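In a GraphDef, each entry in a node's `input` list is a string: `name` or `name:k` refers to output `k` of node `name` (a data input, with `k` defaulting to 0), while a leading `^` marks a control dependency, which only orders execution and does not count toward the op's declared inputs. A `Const` op declares no data inputs, so the error means the imported `cond_1/ExpandDims/dim` node lists `cond_1/Switch:1` as a data input. The plain-Python sketch below classifies input strings this way; `parse_input` and `data_inputs` are hypothetical helpers, not TensorFlow APIs:

```python
def parse_input(spec):
    """Split a GraphDef input string into (node_name, output_index, is_control)."""
    if spec.startswith("^"):
        # Control dependency: refers to a whole node, so there is no output index.
        return spec[1:], None, True
    name, sep, index = spec.rpartition(":")
    if sep and index.isdigit():
        return name, int(index), False
    # No ':k' suffix means output 0 of the named node.
    return spec, 0, False


def data_inputs(inputs):
    """Keep only the inputs that count toward the op's declared inputs."""
    return [s for s in inputs if not parse_input(s)[2]]


# The node as written by tf.train.write_graph: only a control input.
print(data_inputs(["^cond_1/switch_t"]))  # []
# The input named in the import error: a data input, which a Const cannot accept.
print(data_inputs(["cond_1/Switch:1"]))  # ['cond_1/Switch:1']
```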
It seems the expand_dims op received an unexpected argument, cond_1/Switch:1. The .pbtxt file generated by tf.train.write_graph shows an extra input added to ExpandDims/dim:
node {
  name: "cond_1/ExpandDims/dim"
  op: "Const"
  input: "^cond_1/switch_t"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
whereas it should be (as when expand_dims/resize is executed unconditionally):
node {
  name: "ExpandDims/dim"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
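The two nodes differ only by the control input `^cond_1/switch_t`, and inside a TF 1.x `tf.cond` the `switch_t` node is an `Identity` on the true output of the `Switch` node, which matches the `cond_1/Switch:1` named in the error. A rewrite that strips `Identity` nodes must therefore forward a control input as a control input (`^cond_1/Switch`), not as the data input `cond_1/Switch:1`. Whether `graph_util.remove_training_nodes` does this correctly in the TensorFlow version used here is an assumption worth checking, for example by comparing `sess.graph_def` with the rewritten `output_graph_def`. The plain-Python sketch below models nodes as dicts (not real `NodeDef` protos; `bypass_identity` and the `pred` node are hypothetical) and shows the rewiring that keeps the graph valid:

```python
def bypass_identity(nodes, identity_name):
    """Remove one Identity node and rewire its consumers (hypothetical sketch).

    `nodes` maps node name -> {"op": str, "input": [str, ...]}.
    Data consumers are pointed at the Identity's own input; control
    consumers keep a control edge to the upstream *node* (no ':k' suffix).
    """
    source = nodes[identity_name]["input"][0]       # e.g. "cond_1/Switch:1"
    source_node = source.lstrip("^").split(":")[0]  # e.g. "cond_1/Switch"
    rewired = {}
    for name, node in nodes.items():
        if name == identity_name:
            continue  # drop the Identity node itself
        new_inputs = []
        for spec in node["input"]:
            if spec == "^" + identity_name:
                new_inputs.append("^" + source_node)  # stays a control input
            elif spec in (identity_name, identity_name + ":0"):
                new_inputs.append(source)  # stays a data input
            else:
                new_inputs.append(spec)
        rewired[name] = {"op": node["op"], "input": new_inputs}
    return rewired


graph = {
    # "pred" is a hypothetical predicate node; tf.cond feeds it to Switch twice.
    "cond_1/Switch": {"op": "Switch", "input": ["pred", "pred"]},
    "cond_1/switch_t": {"op": "Identity", "input": ["cond_1/Switch:1"]},
    "cond_1/ExpandDims/dim": {"op": "Const", "input": ["^cond_1/switch_t"]},
}
out = bypass_identity(graph, "cond_1/switch_t")
print(out["cond_1/ExpandDims/dim"]["input"])  # ['^cond_1/Switch']
```

Writing `cond_1/Switch:1` instead of `^cond_1/Switch` at that step would reproduce exactly the invalid input reported by `import_graph_def`.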
Am I misunderstanding something? How can I use tf.cond in a graph exported to a .pb file?