在Estimator API中复制的SessionRunHook实例

时间:2018-10-02 13:16:34

标签: python tensorflow

像TF一样,我正在创建我的钩子的副本,而这没有记录,所以我想知道语义是什么吗?我基本上有一个想在我的model_fn中注册一些数据的钩子。我的main()中目前有一个类似于以下内容的代码段:

my_hook = MyHook()
estimator = tf.estimator.Estimator(..., params={'my_hook': my_hook})

这里MyHooktf.train.SessionRunHook的子类。在我的model_fn中,我可以像这样打印MyHook实例的唯一ID:

def model_fn(features, labels, mode, params):
    my_hook = params['my_hook']
    print(id(my_hook))
    ...

在上一次运行中,我得到的值是140384973699344。另一方面,如果我在after_create_session()的{​​{1}}中添加以下内容:

MyHook

在同一运行中,将打印def after_create_session(self, session, coord): print(id(self)) ... 。换句话说,它不再是同一实例。这是一个问题,因为我的140384991235920正在使用要获取的钩子实例注册某些数据,因此在训练估算器时,该数据不再可用。除了在全局寄存器中注册数据,还有其他方法吗?

0 个答案:

没有答案