Tensorflow:检索元图时修改占位符的形状

时间:2016-08-18 15:19:23

标签: neural-network tensorflow deep-learning lstm

我训练了一个递归神经网络(LSTM)并保存了权重和元图。当我检索元数据以进行预测时,只要序列长度与训练期间相同,一切都可以正常工作。

LSTM的一个好处是输入的序列长度可以变化(例如,如果输入是形成句子的字母,则句子的长度可以变化)。

从元图中检索图形时,如何更改输入的序列长度?

有关代码的更多详细信息:

在培训期间,我使用占位符xy来提供数据。对于预测,我检索这些占位符但无法设法更改其形状(从[None, previous_sequence_length=100, n_input][None, new_sequence_length=50, n_input])。

在文件model.py中,定义体系结构和占位符:

 self.x = tf.placeholder("float32", [None, self.n_steps, self.n_input], name='x_input')
 self.y = tf.placeholder("float32", [None, self.n_classes], name='y_labels')
 tf.add_to_collection('x', self.x)
 tf.add_to_collection('y', self.y)
 ...

 def build_model(self):
     #using the placeholder self.x to build the model
     ...
     tf.split(0, self.n_input, self.x) # split input for RNN cell
     ...

在文件prediction.py中我检索预测的元图:

with tf.Session() as sess:
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir)
    new_saver = tf.train.import_meta_graph(latest_checkpoint + '.meta')
    new_saver.restore(sess, latest_checkpoint)
    x = tf.get_collection('x')[0]
    y = tf.get_collection('y')[0]
    ...
    sess.run(..., feed_dict={x: batch_x})

这是我得到的错误:

ValueError: Cannot feed value of shape (128, 50, 2) for Tensor u'placeholders/x_input:0', which has shape '(?, 100, 2)'

注意:当不使用元图而是从头开始重建模型并仅加载保存的权重(而不是元图)时,我设法解决了这个问题。

编辑:将self.n_steps替换为None并使用tf.split(0, self.n_input, self.x)修改tf.split(0, self.x.get_shape()[1], self.x)时出现以下错误:TypeError: Expected int for argument 'num_split' not Dimension(None). < / p>

1 个答案:

答案 0 :(得分:2)

当你定义你的变量时,我建议你把它写成如下

[None, None, n_input]

而不是:

[None, new_sequence_length=50, n_input]

它适用于我的情况。我希望它有所帮助