推断阶段的Tensorflow数据集问题

时间:2018-08-28 18:05:10

标签: tensorflow tensorflow-datasets

我使用Tensorflow here创建了字符级语言生成。我使用了tf.placeholder API,根据google docs

  

馈送是将数据馈入TensorFlow程序的效率最低的方法。

我决定更改代码,并用新的TensroFlow Dataset API替换它。

我使用from_generator创建数据集:

dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32),
                                             (tf.TensorShape([None, None]),
                                              tf.TensorShape([None, None])))
self.iterator = dataset.make_initializable_iterator()
self.inp, self.target = self.iterator.get_next()

从上面的代码中可以看出,我为[None, None]使用了Tensorshape来使模型更具通用性。在培训期间,一切都很好。 但据此推断会出现一些问题。在tf.placeholder API中,我使用了以下代码来生成字符:

def inference(self):
    converter = utils.TextReader(filename=FLAGS.CONVERTER_PATH)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        samples = []
        new_state = sess.run(self.init_state)
        c = 12 # random starting token
        samples.append(c)

        for i in range(1000):
            x = np.zeros((1, 1))
            x[0, 0] = c
            feed_dict = {
                self.inp: x,
                self.init_state: new_state
            }
            preds, new_state = sess.run([self.prediction, self.final_state], feed_dict=feed_dict)
            c = utils.pick_top_n(preds, converter.vocab_size)
            samples.append(c)

        samples = np.array(samples)
        print(converter.arr_to_text(samples))

在数据集API中,我没有tf.placeholder来填充我以前的字符。当我按预期使用上面的代码时,发生了以下错误:

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,50] vs. shape[1] = [32,50]

推断时,模型使用与训练时相同的输入形状([32,50])。这不是我想要的(实际上,我定义了TensorShape([None,None])来处理此问题,但不起作用)。

如何使用新的数据集API解决此问题?

Complete code

0 个答案:

没有答案