Tensorflow - 停止恢复网络参数

时间:2017-08-21 20:14:45

标签: python machine-learning tensorflow

我尝试从张量流网络进行多个连续预测,但性能似乎非常差(对于2层8x8卷积网络,每个预测约500毫秒),即使对于CPU也是如此。我怀疑问题的一部分是它似乎每次都重新加载网络参数。在下面的代码中对classifier.predict的每次调用都会产生以下输出行 - 因此我会看到数百次。

INFO:tensorflow:Restoring parameters from /tmp/model_data/model.ckpt-102001

如何重用已加载的检查点?

(我不能在这里进行批量预测,因为网络的输出是在游戏中进行的移动,然后需要在进入新游戏状态之前应用于当前状态。)

这是进行预测的循环。

def rollout(classifier, state):
  while not state.terminated:
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": state.as_nn_input()}, shuffle=False)
    prediction = next(classifier.predict(input_fn=predict_input_fn))
    index = np.random.choice(NUM_ACTIONS, p=prediction["probabilities"]) # Select a move according to the network's output probabilities
    state.apply_move(index)

classifier是使用...创建的tf.estimator.Estimator

classifier = tf.estimator.Estimator(
      model_fn=cnn_model_fn, model_dir=os.path.join(tempfile.gettempdir(), 'model_data'))

1 个答案:

答案 0 :(得分:0)

Estimator API是一个高级API。

  

tf.estimator框架使构建和训练变得容易   机器学习模型通过其高级Estimator API。估计   提供可以实例化的类,以快速配置通用模型   回归量和分类器等类型。

Estimator API抽象出TensorFlow的大量复杂性,但在此过程中失去了一些普遍性。阅读完代码后,很明显无法在不重新加载模型的情况下运行多个顺序预测。低级TensorFlow API允许此行为。但...

Keras是一个支持此用例的高级框架。简单define the model然后重复调用predict

def rollout(model, state):
  while not state.terminated:
    predictions = model.predict(state.as_nn_input())
    for _, prediction in enumerate(predictions):
      index = np.random.choice(bt.ACTIONS, p=prediction)
      state.apply_mode(index)

不科学的基准测试表明这速度提高了约100倍。