Tensorflow-恢复模型会消耗后续迭代中的所有内存

时间:2017-07-30 09:46:23

标签: machine-learning tensorflow artificial-intelligence

我正在尝试训练一个简单的神经网络,我需要在其中保存模型,加载新数据集并恢复模型。这项工作保存 - 恢复过程在3或4次迭代后消耗了我的所有内存。这是我的代码的相关部分。 runsess()函数在循环中多次迭代。

num_steps = 50

with tf.Session(graph=graph) as session:
   tf.global_variables_initializer().run()
   print("Initialized")
   for step in range(num_steps):
     offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
     batch_data = train_dataset[offset:(offset + batch_size)]
     batch_labels = train_labels[offset:(offset + batch_size)]
     feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}

     _, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict)

     if (step % 10 == 0):
       print("Minibatch loss at step %d: %f" % (step, l))
       print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels))
       print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels))


   print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels))

   # Saving the session
   saver = tf.train.Saver()
   save_path = "./checkpoints/model.ckpt"
   saver.save(session, save_path)
   print("Model saved in file: %s" % save_path)
   session.close()


def runsess(graph,num_steps):
  with tf.Session(graph=graph) as session:
    saver = tf.train.import_meta_graph('./checkpoints/model.ckpt.meta')
    saver.restore(session,tf.train.latest_checkpoint('./checkpoints/'))
    tf.global_variables_initializer()
    print("Initialized")
    tf.get_default_graph()
    for step in range(num_steps):
      offset = (step * batch_size) % (train_labels.shape[0] - batch_size)
      batch_data = train_dataset[offset:(offset + batch_size)]
      batch_labels = train_labels[offset:(offset + batch_size)]
      feed_dict = {tf_train_dataset : batch_data, tf_train_labels : batch_labels}

      _, l, predictions = session.run([optimizer, loss, train_prediction], feed_dict=feed_dict)
      if (step % 10 == 0):
        print("Minibatch loss at step %d: %f" % (step, l))
        print("Minibatch accuracy: %.1f%%" % accuracy(predictions, batch_labels))
        print("Validation accuracy: %.1f%%\n" % accuracy(valid_prediction.eval(), valid_labels))
    print("Test accuracy: %.1f%%" % accuracy(test_prediction.eval(), test_labels))

    saver.save(session, save_path)
    session.close()

在runsess()中保存模型时我似乎犯了一个错误,但我不明白在哪里以及如何。我该如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

问题是您调用在图表中创建新操作的函数,从而导致内存消耗。

首先,您应该拥有一个保护程序,您无法在循环中创建它,每个保护程序都会创建分配操作。

saver = tf.train.import_meta_graph('./checkpoints/model.ckpt.meta')
saver.restore(session,tf.train.latest_checkpoint('./checkpoints/'))

其次,global_variables_initializer也会创建一个op(实际上很多的操作),但是你调用它并且甚至不存储结果。尽管有这个名字,但这个函数不会初始化变量 - 它会创建一个op来执行此操作 - 只需在您的情况下删除它。

tf.global_variables_initializer()

了解哪些函数修改图表或不修改图表的最简单方法是在完成构建后始终完成图表。这样每次调用一个创建新操作的函数时 - 你会得到一个异常告诉你,你将能够调试它。

tf.get_default_graph().finalize()