我正在关注_tensorflow.org上的this教程。 我正在尝试正确处理 input_fn _ ,以在 .fit()中用作参数。 我创建了分类器:
classifier = tf.contrib.learn.SKCompat(tf.contrib.learn.DNNClassifier(
feature_columns=feature_cols,
hidden_units=[10, 10],
model_dir=("C:\\........\tmp"),
n_classes=2,
activation_fn=tf.sigmoid,
optimizer=tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001
)))
然后输入功能:
def input_fn(data_set):
feature_cols = {k: tf.constant(data_set[k].values)
for k in FEATURES}
labels = tf.constant(data_set[LABEL].values)
return feature_cols, labels
最后我将 input_fn()放在 fit()中:
classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)
当我运行代码时,我收到此错误:
TypeError Traceback (most recent call last)
<ipython-input-6-938bcd2f929f> in <module>()
----> 1 classifier.fit(input_fn=lambda: input_fn(training_set), steps=10)
TypeError: fit() got an unexpected keyword argument 'input_fn'
我不知道它是关于 input_fn 定义还是 fit 参数
答案 0 :(得分:0)
如果您想使用input_fn,请不要使用SKCompat,将第一行替换为:
classifier = tf.contrib.learn.DNNClassifier(
根据需要调整括号。