我使用TensorFlow(1.1)高级API Estimators来创建我的神经网络。但我将它用于一个类,我必须调用我的类的实例来生成神经网络的模型。 (这里self.a
)
class NeuralNetwork(object):
def __init__(self):
""" Create neural net """
regressor = tf.estimator.Estimator(model_fn=self.my_model_fn,
model_dir="/tmp/data")
// ...
def my_model_fn(self, features, labels, mode):
""" Generate neural net model """
self.a = a
predictions = ...
loss = ...
train_op = ...
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op)
但我收到错误:
ValueError: model_fn [...] has following not expected args: ['self']
。
我尝试删除模型args的self
,但又出现了另一个错误TypeError: … got multiple values for keyword argument。
有没有办法将这些EstimatorSpec用于一个类?
答案 0 :(得分:1)
看起来Estimator
的参数检查有点过于热心。作为一种解决方法,您可以将成员函数model_fn
包装在lambda
中,如下所示:
import tensorflow as tf
class ModelClass(object):
def __init__(self):
self._constant = 2.
self.regressor = tf.estimator.Estimator(
model_fn=lambda features, labels, mode: self._model_fn(
features, labels, mode))
def _model_fn(self, features, labels, mode):
loss = tf.constant(self._constant)
train_op = tf.no_op()
return tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op)
ModelClass()
但是,这很烦人。您是否介意filing a feature request on Github放宽此参数检查成员函数?
更新:应在TensorFlow 1.3+中修复。谢谢,袁!