How to correctly create an Estimator with TensorFlow

Date: 2018-06-27 13:08:35

Tags: python tensorflow tensorflow-estimator

I want to build a neural network in Python, but I'm having trouble with estimators.

First, I read some of the documentation on EstimatorSpec, and I think I created the estimator specs correctly:

estimate_train = tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)

estimate_test = tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL, loss=loss) 

But when I try to create the Estimator that will train my network:

estimator_ = tf.estimator.Estimator(model_fn= estimate_train, model_dir="Path")

I get the following error:

TypeError                                 Traceback (most recent call last)
/usr/lib/python3.5/inspect.py in getfullargspec(func)
   1088                                        skip_bound_arg=False,
-> 1089                                        sigcls=Signature)
   1090     except Exception as ex:

/usr/lib/python3.5/inspect.py in _signature_from_callable(obj, follow_wrapper_chains, skip_bound_arg, sigcls)
   2155     if not callable(obj):
-> 2156         raise TypeError('{!r} is not a callable object'.format(obj))
   2157 

TypeError: EstimatorSpec(mode='eval', predictions={}, loss=<tf.Tensor 'mean_squared_error/value:0' shape=() dtype=float32>, train_op=None, eval_metric_ops={}, export_outputs=None, training_chief_hooks=(), training_hooks=(), scaffold=<tensorflow.python.training.monitored_session.Scaffold object at 0x7f9ddc083748>, evaluation_hooks=(), prediction_hooks=()) is not a callable object

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-2-c86a69b4da46> in <module>()
     39 
     40 
---> 41 estimatorn = tf.estimator.Estimator(model_fn= estimate_test, model_dir="/home/jabou/Bureau")

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in __init__(self, model_fn, model_dir, config, params, warm_start_from)
    221     if model_fn is None:
    222       raise ValueError('model_fn must be provided to Estimator.')
--> 223     _verify_model_fn_args(model_fn, params)
    224     self._model_fn = model_fn
    225     self._params = copy.deepcopy(params or {})

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py in _verify_model_fn_args(model_fn, params)
   1212 def _verify_model_fn_args(model_fn, params):
   1213   """Verifies model fn arguments."""
-> 1214   args = set(util.fn_args(model_fn))
   1215   if 'features' not in args:
   1216     raise ValueError('model_fn (%s) must include features argument.' % model_fn)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/util.py in fn_args(fn)
     58     if _is_callable_object(fn):
     59       fn = fn.__call__
---> 60     args = tf_inspect.getfullargspec(fn).args
     61     if _is_bounded_method(fn):
     62       args.remove('self')

/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_inspect.py in getfullargspec(obj)
     88   decorators, target = tf_decorator.unwrap(obj)
     89   return next((d.decorator_argspec for d in decorators
---> 90                if d.decorator_argspec is not None), spec_fn(target))
     91 
     92 

/usr/lib/python3.5/inspect.py in getfullargspec(func)
   1093         # else. So to be fully backwards compatible, we catch all
   1094         # possible exceptions here, and reraise a TypeError.
-> 1095         raise TypeError('unsupported callable') from ex
   1096 
   1097     args = []

TypeError: unsupported callable

Here is my full code:

import tensorflow as tf

sess = tf.Session()

tf.reset_default_graph()

batch_size = 20

# Values needed to create the network

input_ = tf.placeholder(tf.float32, shape=(batch_size, 1 , 1 ,1))

filter_ = tf.placeholder(tf.float32, shape=(batch_size, 1 , 2700 ,1))

output_network = tf.placeholder(tf.int32, shape=(4,))

output_real = tf.placeholder(tf.float32)

x_var = tf.get_variable(name = 'x_var', dtype = tf.float32, initializer = tf.random_normal((batch_size,1,1,1), 0, 0.001)) # Initialised values

bias = tf.Variable(tf.zeros([2700]))


# Network

logits = tf.nn.conv2d_transpose(x_var, filter_ ,output_network,[1,1,3,1],'SAME') + bias

loss = tf.losses.mean_squared_error(output_real,logits)  # loss function

optimizer = tf.train.AdamOptimizer(learning_rate=0.001) 

train_op = optimizer.minimize(loss=loss,global_step=tf.train.get_global_step())

# Estimators specification

estimate_train = tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)

estimate_test = tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL, loss=loss)

# Estimator

estimator_ = tf.estimator.Estimator(model_fn= estimate_test, model_dir="Path")

Can you help me?

Edit

Following @f4's answer, I corrected my code, but I still get the same error:

import tensorflow as tf

sess = tf.Session()

tf.reset_default_graph()

batch_size = 20

def model(param):

    # Values needed to create the network

    input_ = tf.placeholder(tf.float32, shape=(batch_size, 1 , 1 ,1))

    filter_ = tf.placeholder(tf.float32, shape=(batch_size, 1 , 2700 ,1))

    output_network = tf.placeholder(tf.int32, shape=(4,))

    output_real = tf.placeholder(tf.float32)

    x_var = tf.get_variable(name = 'x_var', dtype = tf.float32, initializer = tf.random_normal((batch_size,1,1,1), 0, 0.001)) # Initialised values

    bias = tf.Variable(tf.zeros([2700]))


    # Network

    logits = tf.nn.conv2d_transpose(x_var, filter_ ,output_network,[1,1,3,1],'SAME') + bias

    loss = tf.losses.mean_squared_error(output_real,logits)  # loss function

    optimizer = tf.train.AdamOptimizer(learning_rate=0.001) 

    train_op = optimizer.minimize(loss=loss,global_step=tf.train.get_global_step())

    # Estimators specification
    if param == "train":
        return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)

    if param == "test":
        return tf.estimator.EstimatorSpec(mode=tf.estimator.ModeKeys.EVAL, loss=loss)



# Estimator

estimator_ = tf.estimator.Estimator(model_fn= model("train"), model_dir="Path") 

What is wrong now?

1 Answer:

Answer 0 (score: 1)

You are passing an EstimatorSpec directly to the Estimator, which is incorrect.

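model_fn should be a function that returns an EstimatorSpec instance; the Estimator calls this function later. That is why it complains that what you provided is not callable.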
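The failure can be reproduced without TensorFlow. The traceback shows that the Estimator inspects model_fn with inspect.getfullargspec, which raises "TypeError: unsupported callable" for any non-callable object. A minimal sketch, assuming a hypothetical FakeSpec namedtuple as a stand-in for tf.estimator.EstimatorSpec (it is not the real class):

```python
import inspect
from collections import namedtuple

# Hypothetical stand-in for tf.estimator.EstimatorSpec: a plain data object,
# not something that can be called.
FakeSpec = namedtuple("FakeSpec", ["mode", "loss"])


def model_fn(features, labels, mode, params):
    """A function that *returns* a spec; this is what the Estimator expects."""
    return FakeSpec(mode=mode, loss=0.0)


def inspect_args(fn):
    """Return fn's argument names, or the error message if inspection fails."""
    try:
        return inspect.getfullargspec(fn).args
    except TypeError as ex:
        return "TypeError: {}".format(ex)


print(inspect_args(model_fn))                          # the function itself: OK
print(inspect_args(model_fn(None, None, "eval", {})))  # its return value: fails
```

The first call prints the four argument names; the second prints the same "unsupported callable" message that ends the traceback in the question.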

Edit

Again, you should not pass the function's return value; you need to pass the function itself, which is what can be called:

estimator_ = tf.estimator.Estimator(model_fn = model, model_dir="Path") 
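Passing model instead of model("train") gets past the TypeError, but the traceback also shows the next check in _verify_model_fn_args: the argument list must contain features. The sketch below is a hypothetical, simplified mimic of only that check (verify_model_fn_args and check are illustrative names, not TensorFlow functions), showing that the edited model(param) would be rejected next with a ValueError:

```python
import inspect


def verify_model_fn_args(model_fn):
    """Simplified mimic of the check in the traceback (hypothetical)."""
    args = set(inspect.getfullargspec(model_fn).args)  # TypeError if not callable
    if "features" not in args:
        raise ValueError("model_fn (%s) must include features argument." % model_fn)
    return sorted(args)


def model(param):  # the signature from the question's edit
    return None


def my_model_fn(features, labels, mode, params):  # the documented signature
    return None


def check(fn):
    """Return the accepted argument names, or the name of the error raised."""
    try:
        return verify_model_fn_args(fn)
    except (TypeError, ValueError) as ex:
        return type(ex).__name__


print(check(model("train")))  # TypeError: a return value is not callable
print(check(model))           # ValueError: no 'features' argument
print(check(my_model_fn))     # ['features', 'labels', 'mode', 'params']
```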

Your model_fn is also not right. I suggest reading the documentation: https://www.tensorflow.org/get_started/custom_estimators

It should have the following signature:

def my_model_fn(
    features,  # This is batch_features from input_fn
    labels,    # This is batch_labels from input_fn
    mode,      # An instance of tf.estimator.ModeKeys
    params):   # Additional configuration
    ...        # build the model here and return a tf.estimator.EstimatorSpec

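The first two arguments are the batches of features and labels returned from the input function; that is, features and labels are the handles to the data your model will use. The mode argument indicates whether the caller is requesting training, predicting, or evaluation.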

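You should use features and labels instead of creating placeholders. Also, in your model_fn, you should not create the training op when you are not in training mode.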
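The control flow described above (use the features and labels batches, and build the training op only in training mode) can be sketched without TensorFlow. This is a framework-free illustration under stated assumptions: Spec is a hypothetical stand-in for tf.estimator.EstimatorSpec, the mode constants are stand-ins whose string values match tf.estimator.ModeKeys TRAIN/EVAL/PREDICT, the model is a toy single-weight linear model, and the "train op" is just the weight after one gradient step:

```python
from collections import namedtuple

# Hypothetical stand-ins for tf.estimator.ModeKeys and tf.estimator.EstimatorSpec.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"
Spec = namedtuple("Spec", ["mode", "predictions", "loss", "train_op"])


def my_model_fn(features, labels, mode, params):
    # Predictions come from the `features` batch, not from a placeholder.
    w = params["weight"]
    predictions = [w * x for x in features]
    if mode == PREDICT:
        return Spec(mode, predictions, None, None)  # no labels, no loss

    # Mean squared error against the `labels` batch, as in the question.
    n = len(labels)
    loss = sum((p - y) ** 2 for p, y in zip(predictions, labels)) / n
    if mode == EVAL:
        return Spec(mode, predictions, loss, None)  # no training op in EVAL

    # Only TRAIN mode builds a "training op". Stand-in for optimizer.minimize():
    # the weight after one gradient-descent step on the MSE loss.
    grad = 2 * sum((p - y) * x for p, y, x in zip(predictions, labels, features)) / n
    train_op = w - params["learning_rate"] * grad
    return Spec(mode, predictions, loss, train_op)


params = {"weight": 1.0, "learning_rate": 0.1}
print(my_model_fn([1.0, 2.0, 3.0], [2.0, 4.0, 6.0], EVAL, params))
print(my_model_fn([1.0, 2.0, 3.0], [2.0, 4.0, 6.0], TRAIN, params))
```

In actual TF 1.x code each branch would return tf.estimator.EstimatorSpec(mode=mode, ...), and the TRAIN branch would build its op with optimizer.minimize(loss=loss, global_step=tf.train.get_global_step()), as in the question.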