How do I use loops within TensorFlow scopes and sessions?

Time: 2017-02-04 01:23:54

Tags: python tensorflow

How can I resolve the following error?

ValueError: Variable RNN/MultiRNNCell/Cell0/BasicLSTMCell/Linear/Matrix does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

This error is thrown when I loop over my training data to feed batches to a MultiRNNCell-based model (see Main below). The _model (a MultiRNNCell) and its components, i.e. the BasicLSTMCell in the code below (see Model), are not created through tf.get_variable(...), so calling the model repeatedly raises the error.
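
For reference, my understanding of the variable-sharing contract that the error message refers to: once a scope is marked for reuse, tf.get_variable() only returns variables that were already created under that scope, and requesting a name that was never created raises exactly this ValueError. A minimal, self-contained sketch (the scope and variable names are illustrative, not taken from my model):

    import tensorflow as tf

    with tf.Graph().as_default():
        with tf.variable_scope("demo") as scope:
            # First call creates the variable through tf.get_variable().
            w1 = tf.get_variable("w", shape=[3, 3])
            # After marking the scope for reuse, the same name resolves to the
            # same variable object.
            scope.reuse_variables()
            w2 = tf.get_variable("w", shape=[3, 3])
            assert w1 is w2
            # Requesting a name that was never created under a reusing scope
            # raises the "Variable ... does not exist" ValueError shown above.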

Main

import tensorflow as tf

def run_epoch(session, model, feed_dict):
    # Calling the model here builds/enters the RNN graph again for every batch.
    logits = model(inputs)
    session.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    scores = session.run([logits], feed_dict=feed_dict)

with tf.Graph().as_default():
    # ... Various tensorflow placeholders and variables.

    with tf.variable_scope('Train', reuse=None):
        train_model = lstm_model.MultiRNNLSTM(config, is_training=True)

    with tf.Session() as sess:
        story_cnt = train_stories.shape[0]

        for epoch in range(5):
            print('----- Epoch', epoch, '-----')
            total_loss = 0
            for i in range(story_cnt // BATCH_SIZE):
                inst_story = train_stories[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]
                inst_order = train_orders[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]
                feed_dict = {story: inst_story, order: inst_order}

                run_epoch(sess, train_model, feed_dict)

Model

    # Inside the model class that encapsulates RNN instantiation and calls.

    with tf.variable_scope("RNN"):

        for time_step in range(self._config['num_steps']):
            if time_step > 0:
                print('Reusing RNN variables...')
                tf.get_variable_scope().reuse_variables()

            (output, state) = self._model(inputs[:, time_step, :], state)
            outputs.append(output)
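
For comparison, a pattern I have seen that avoids the manual time-step loop is letting tf.nn.dynamic_rnn unroll the cell; it shares the cell's weights across time steps internally, so no explicit reuse_variables() call is needed. This is only a minimal sketch with made-up shapes, not a drop-in replacement for the model above:

    import tensorflow as tf

    # Illustrative shapes only; they do not correspond to the model above.
    batch_size, num_steps, input_dim, hidden_dim, num_layers = 32, 10, 50, 128, 2

    inputs = tf.placeholder(tf.float32, [batch_size, num_steps, input_dim])

    # Depending on the TensorFlow 1.x version, the cells live under
    # tf.nn.rnn_cell or tf.contrib.rnn.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(hidden_dim) for _ in range(num_layers)])

    with tf.variable_scope("RNN"):
        # dynamic_rnn unrolls the cell over the time dimension and reuses the
        # cell's weights at every step, replacing the Python-level loop and
        # the reuse_variables() call.
        outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)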

0 Answers:

No answers