当可变形状改变时从检查点恢复

时间:2016-02-23 18:10:11

标签: tensorflow

我无法恢复包含改变形状的变量的检查点模型。例如,使用这个简单的模型:

var = tf.get_variable(initializer=tf.constant_initializer([0]), shape=[1], trainable=False, name='var')
op = tf.assign(var, [1, 2], validate_shape=False)
saver = tf.train.Saver(reshape=False)

如果我运行op然后保存模型,当我尝试恢复它时,我收到以下错误:

Assign requires shapes of both tensors to match. lhs shape= [1] rhs shape= [2]
 [[Node: save/Assign = Assign[T=DT_FLOAT, use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](var, save/restore_slice)]]

这似乎与变化的形状和Saver试图验证形状有关。如果我在构建reshape时将True设置为Saver,根据文档应解决此问题,我会收到此错误:

Input to reshape is a tensor with 2 values, but the requested shape has 1
 [[Node: save/Reshape = Reshape[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](save/restore_slice, save/Reshape/shape)]]

我倾向于认为这是一个错误。

2 个答案:

答案 0 :(得分:1)

Saver的重塑选项仅在形状具有相同的元素总数时才有效。例如,它允许您从形状为[1]的数据加载形状为[]的变量,或者从形状为[15, 7]的数据加载形状为[5, 21]的变量。如果形状不以这种方式兼容,那么您必须构建一个新图形。

答案 1 :(得分:0)

通过添加代码

加载元图时将validate_shape设置为False
if graph.node[-1].attr.get("validate_shape"):
    graph.node[-1].attr["validate_shape"].b = False

到tensorflow / python / framework / ops.py#2318

with self._lock:
  graph = graph_pb2.GraphDef()
  graph.versions.CopyFrom(self._graph_def_versions)
  bytesize = 0
  for op_id in sorted(self._nodes_by_id):
    op = self._nodes_by_id[op_id]
    if from_version is None or op_id > from_version:
      graph.node.extend([op.node_def])
      if graph.node[-1].attr.get("validate_shape"):
        graph.node[-1].attr["validate_shape"].b = False
      if op.outputs and add_shapes:
        assert "_output_shapes" not in graph.node[-1].attr
        graph.node[-1].attr["_output_shapes"].list.shape.extend([
            output.get_shape().as_proto() for output in op.outputs])
      bytesize += op.node_def.ByteSize()
      if bytesize >= (1 << 31) or bytesize < 0:
        raise ValueError("GraphDef cannot be larger than 2GB.")