直接向tf.contrib.learn估算者提供输入

时间:2017-03-14 13:53:51

标签: machine-learning tensorflow neural-network

我在使用TensorFlow的tf.contrib.learn中的DNNRegressor估算器时遇到了问题。在估算器的documentation page中,提出了两种提供输入的方法。

第一种方法使用 input_fn 函数,如上所述,该函数应用于预处理并将输入提供给估算器,第二种方法直接提供输入。例子:

def input_function:
    ...
    return feature_cols, label

estimator.fit(input_fn=input_function, steps=...)

在这种情况下, feature_cols dict ,其中包含:

  • 键: string 指定列名
  • 值: tf.constant 指定列值

标签是包含标签的单个 tf.constant 列。

这很有用。

X_train = ...
y_train = ...
estimator.fit(x=X_train, y=y_train, steps=...)

在这种情况下,我不知道要以 X y 的形式提供什么。我尝试过以下方法:

  • 普通的numpy数组。这是一个很长的镜头,并没有与之合作。错误消息为:KeyError: 'my_column0'
  • Pandas DataFrame,其列对应于已定义的列名(在估算器的初始化时定义)。我会再次获得相同的KeyError,即使密钥现在应该存在。
  • 传递X = feature_cols 和y = 标签,其定义方式与上述input_fn相同。这会产生:ValueError: Inputs cannot be tensors. Please provide input_fn.

我还尝试了与 dict 和numpy数组的其他组合,但没有任何效果。我希望能够使用第二种方法完成这项工作,因为这对于将对象传递给 evaluate predict 也很有用。有谁知道这个的正确格式? 还有,有没有办法简单地传递numpy数组?

谢谢!

TL;博士 tf.contrib.learn估算工具的输入应该是什么才能使用estimator.fit(x=X_train, y=y_train, steps=...)直接提供?

1 个答案:

答案 0 :(得分:0)

我想我错过了以下警告: 有些论点已被弃用。它们将在2016-12-01之后删除。更新说明:通过迁移到单独的SKCompat类,Estimator与Scikit Learn界面分离。参数x,y和batch_size仅在SKCompat类中可用,Estimator仅接受input_fn。示例转换:est = Estimator(...) - > est = SKCompat(Estimator(...))

奇怪的是,在2016-12-01之后,论点仍然存在。