tensorflow-使用estimator API

时间:2018-11-20 23:47:30

标签: python tensorflow deep-learning reinforcement-learning tensorflow-estimator

我尝试使用experience replay memory API来实现tf.estimator.Estimator。但是,我不确定什么是在所有模式(TRAINEVALUATEPREDICT)下都可以使用的效果最好的最佳方法是什么。我尝试了以下方法:

  • tf.Variable填充内存,这会导致批处理和输入管道出现问题(我无法在测试或预测阶段输入自定义体验)

,目前尝试:

  • tf.Graph外部实现内存。每次运行后使用tf.train.SessionRunHook设置值。在训练和测试过程中,使用tf.data.Dataset.from_generator()加载体验。自己管理状态。

我在几点上都失败了,并且开始相信tf.estimator.Estimator API并没有为我提供必要的接口来轻松写下来。

一些代码(第一种方法,它的批处理大小失败,因为它是为exp的切片而固定的,因此我无法将模型用于预测或评估):

 def model_fn(self, features, labels, mode, params):
    batch_size = features["matrix"].get_shape()[0].value

    # get prev_exp
    if mode == tf.estimator.ModeKeys.TRAIN:
        erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
        prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])

    # model
    pred = model(features["matrix"], prev_exp, params) 

但是:将erm包含在功能字典中会更好。但是然后我必须在图表外管理erm,并用SessionRunHook回写我的最新经验。有什么更好的方法还是我错过了什么?

1 个答案:

答案 0 :(得分:0)

我通过在图形外部实现ERM,使用tf.data.Dataset.from_generator()将其反馈回输入管道并使用SessionRunHooks回写来解决了我的问题。是的,非常乏味,但是可以。