我正在使用a code from OpenAI本应该使用set_shape()
函数将图表使用大小为1的批处理转换为任意大小的批处理。
据说这可能与TF版本< 1.4一起工作但是不再工作了,请参阅this issue。
这是代码:
with tf.gfile.FastGFile(os.path.join(
MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
# Works with an arbitrary minibatch size.
with tf.Session() as sess:
pool3 = sess.graph.get_tensor_by_name('pool_3:0')
ops = pool3.graph.get_operations()
for op_idx, op in enumerate(ops):
for o in op.outputs:
shape = o.get_shape()
shape = [s.value for s in shape]
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.set_shape(tf.TensorShape(new_shape))
这似乎是预期的行为,因为@mrry说in this answer因为set_shape()
应该改善关于形状的信息而不是相反。
那么我将如何以一种相当紧张的方式(一般)将适应批量大小1的图形更改为批量大小?(无需收集所有权重并手动重新定义每个操作)