具有numpy数组的估算器input_fn

时间:2017-12-20 23:59:34

标签: python numpy tensorflow lstm recurrent-neural-network

我正在创建一个带有numpy数组的估算器,以便使用tf.estimator.inputs.numpy_input_fn提供给模型。如下所示:

def input_fun(data):
    x, y = data

    x, y = np.reshape(x, (batch_size, -1, 1)), \
           np.reshape(y, (batch_size, -1, 1))

    return tf.estimator.inputs.numpy_input_fn({'x': x}, y)

def forward(x, params, mode):

    layers = [tf.nn.rnn_cell.LSTMCell(n_neurons) for _ in range(n_layers)]
    cells = tf.nn.rnn_cell.MultiRNNCell(layers)
    outputs, state = tf.nn.dynamic_rnn(cells, x)

    predictions = ...

    return predictions

def model_fn(features, labels, mode, params):
    predict = forward(features, params, mode)

    return tf.estimator.EstimatorSpec(predict , ...)

def experiment_fn(config, params):
    return learn.Experiment(
        estimator = estimator(model_fn,...),
        train_input_fn = lambda: input_fun(train_set),
        eval_input_fn = lambda: input_fun(eval_set))

它引发了以下内容:

  

追踪(最近一次呼叫最后一次):

     

文件“”,第1行,in       RUNFILE( '/ Experiment.py',   WDIR = '/ TensorFlow')

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ Spyder的\ utils的\网站\ sitecustomize.py”   第710行,在runfile中       execfile(filename,namespace)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ Spyder的\ utils的\网站\ sitecustomize.py”   第101行,在execfile中       exec(compile(f.read(),filename,'e​​xec'),namespace)

     

文件“/Experiment.py”,第490行,in       hparams = params

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \的contrib \学习\ python的\学习\ learn_runner.py”   第218行,在运行中       return _execute_schedule(experiment,schedule)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \的contrib \学习\ python的\学习\ learn_runner.py”   第46行,在_execute_schedule中       return task()

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \的contrib \学习\ python的\学习\ experiment.py”   第367行,在火车上       hooks = self._train_monitors + extra_hooks)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \的contrib \学习\ python的\学习\ experiment.py”   第807行,在_call_train中       钩=钩)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\估计\ estimator.py”   302号线,在火车上       loss = self._train_model(input_fn,hooks,saving_listeners)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\估计\ estimator.py”   第711行,在_train_model中       功能,标签,model_fn_lib.ModeKeys.TRAIN,self.config)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\估计\ estimator.py”   第694行,在_call_model_fn中       model_fn_results = self._model_fn(features = features,** kwargs)

     

文件“/Experiment.py”,第350行,在model_fn中       预测=前进(特征,参数,模式)

     

文件“/Experiment.py”,第335行,正向       dtype = tf.float32

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\ OPS \ rnn.py”   第562行,在dynamic_rnn中       flat_input = [op_convert_to_tensor(input_)for input_ in flat_input]

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\ OPS \ rnn.py”   第562行       flat_input = [op_convert_to_tensor(input_)for input_ in flat_input]

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ ops.py”   第836行,在convert_to_tensor中       as_ref =假)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ ops.py”   第926行,在internal_convert_to_tensor中       ret = conversion_func(value,dtype = dtype,name = name,as_ref = as_ref)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ constant_op.py”   第229行,在_constant_tensor_conversion_function中       返回常量(v,dtype = dtype,name = name)

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ constant_op.py”   208行,常数       value,dtype = dtype,shape = shape,verify_shape = verify_shape))

     

文件   “C:\用户\ HP \ Anaconda3 \ LIB \站点包\ tensorflow \ python的\框架\ tensor_util.py”   第472行,在make_tensor_proto中       “支持类型。” %(类型(值),值))

     

TypeError:无法转换类型为< class'function'>的对象至   张量。内容:< function numpy_input_fn。< locals> .input_fn at   0x000001AB2B1DBEA0取代。考虑将元素转换为支持的类型。

有谁知道为什么?

2 个答案:

答案 0 :(得分:1)

我有类似的问题。在我的情况下引发了异常,因为在我的模型中(我猜"转发",在你的情况下)x被用作Tensor,但它实际上是一个函数(特别是tf.estimator.inputs.numpy_input_fn) 。 我想通过添加这个来解决这个问题:

print(x)
print(type(x))

其中印有这样的东西:

<function numpy_input_fn.<locals>.input_fn at 0x7fcc6f065740>
<class 'function'>

我仍然不确定解决问题的正确方法是什么,但我能够通过做类似的事情来修复它:

input_dict, y = x()
x = input_dict['x']

希望有所帮助

答案 1 :(得分:0)

您应该将单元格列表传递到MultiRNNCell

  

参数数量:

     

cells :将按此顺序组成的RNNCell列表。

     

state_is_tuple :如果为True,则接受和返回的状态为n元组,   其中n = len(cells)。如果为False,则状态都连接在一起   列轴。后一种行为很快就会被弃用。

如果您真的想制作单层RNN,请将代码更改为

cells = tf.nn.rnn_cell.MultiRNNCell([layers])

...或创建更多图层。