TensorFlow Estimator:model_fn有以下非预期的args:[' self']

时间:2017-05-03 09:02:39

标签: python python-2.7 tensorflow

我使用TensorFlow(1.1)高级API Estimators来创建我的神经网络。但我将它用于一个类,我必须调用我的类的实例来生成神经网络的模型。 (这里self.a

class NeuralNetwork(object):
  def __init__(self):
    """ Create neural net """
    regressor = tf.estimator.Estimator(model_fn=self.my_model_fn,
                                       model_dir="/tmp/data")
    // ...

  def my_model_fn(self, features, labels, mode):
  """ Generate neural net model """
    self.a = a
    predictions = ...
    loss = ...
    train_op = ...
    return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op)

但我收到错误: ValueError: model_fn [...] has following not expected args: ['self']。 我尝试删除模型args的self,但又出现了另一个错误TypeError: … got multiple values for keyword argument。 有没有办法将这些EstimatorSpec用于一个类?

1 个答案:

答案 0 :(得分:1)

看起来Estimator的参数检查有点过于热心。作为一种解决方法,您可以将成员函数model_fn包装在lambda中,如下所示:

import tensorflow as tf

class ModelClass(object):

  def __init__(self):
    self._constant = 2.
    self.regressor = tf.estimator.Estimator(
        model_fn=lambda features, labels, mode: self._model_fn(
            features, labels, mode))

  def _model_fn(self, features, labels, mode):
    loss = tf.constant(self._constant)
    train_op = tf.no_op()
    return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op)

ModelClass()

但是,这很烦人。您是否介意filing a feature request on Github放宽此参数检查成员函数?

更新:应在TensorFlow 1.3+中修复。谢谢,袁!