我使用Tensorflow here创建了字符级语言生成。我使用了tf.placeholder
API,根据google docs:
馈送是将数据馈入TensorFlow程序的效率最低的方法。
我决定更改代码,并用新的TensroFlow Dataset API替换它。
我使用from_generator创建数据集:
dataset = tf.data.Dataset.from_generator(gen, (tf.int32, tf.int32),
(tf.TensorShape([None, None]),
tf.TensorShape([None, None])))
self.iterator = dataset.make_initializable_iterator()
self.inp, self.target = self.iterator.get_next()
从上面的代码中可以看出,我为[None, None]
使用了Tensorshape
来使模型更具通用性。在培训期间,一切都很好。 但据此推断会出现一些问题。在tf.placeholder
API中,我使用了以下代码来生成字符:
def inference(self):
converter = utils.TextReader(filename=FLAGS.CONVERTER_PATH)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
samples = []
new_state = sess.run(self.init_state)
c = 12 # random starting token
samples.append(c)
for i in range(1000):
x = np.zeros((1, 1))
x[0, 0] = c
feed_dict = {
self.inp: x,
self.init_state: new_state
}
preds, new_state = sess.run([self.prediction, self.final_state], feed_dict=feed_dict)
c = utils.pick_top_n(preds, converter.vocab_size)
samples.append(c)
samples = np.array(samples)
print(converter.arr_to_text(samples))
在数据集API中,我没有tf.placeholder
来填充我以前的字符。当我按预期使用上面的代码时,发生了以下错误:
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,50] vs. shape[1] = [32,50]
推断时,模型使用与训练时相同的输入形状([32,50]
)。这不是我想要的(实际上,我定义了TensorShape([None,None])来处理此问题,但不起作用)。
如何使用新的数据集API解决此问题?