Sketch_RNN,ValueError:无法输入形状的值

时间:2019-01-09 23:17:04

标签: tensorflow magenta

我收到以下错误:

ValueError:无法为形状为“(1,117,5)”的Tensor u'vector_rnn_1 / Placeholder_1:0'输入形状(1、2、251、5)的值

从此处运行代码时 https://github.com/tensorflow/magenta-demos/blob/master/jupyter-notebooks/Sketch_RNN.ipynb

此方法中发生错误:

def encode(input_strokes):
  strokes = to_big_strokes(input_strokes).tolist()
  strokes.insert(0, [0, 0, 1, 0, 0])
  seq_len = [len(input_strokes)]
  draw_strokes(to_normal_strokes(np.array(strokes)))
  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]

我不得不提到,我按照此处的说明训练了自己的模型:

https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn

有人可以帮助我理解和解决此问题吗?

谢谢 问候

2 个答案:

答案 0 :(得分:0)

问题是笔触大小与算法期望的数组大小不相等。 因此,修改笔触数组可解决此问题。

答案 1 :(得分:0)

就我而言,问题是由to_big_strokes()函数引起的。如果您不修改sketch_rnn / utils.py中的to_big_stroke(),则默认情况下它将把input_strokes序列延长为250。
您需要做的就是在该函数中修改参数max_len。您需要将该值更改为您自己的数据集的最大序列长度,这对我来说是21,如下面所示的标有“更改”的行。

def to_big_strokes(stroke, max_len=21):  # change: 250 -> 21
  """Converts from stroke-3 to stroke-5 format and pads to given length."""
  # (But does not insert special start token).

  result = np.zeros((max_len, 5), dtype=float)
  l = len(stroke)
  assert l <= max_len
  result[0:l, 0:2] = stroke[:, 0:2]
  result[0:l, 3] = stroke[:, 2]
  result[0:l, 2] = 1 - result[0:l, 3]
  result[l:, 4] = 1
  return result