我按照谷歌的这份指南(https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/tensorflow/d_experiment.ipynb)建立了一个简单的线性回归模型。
在该笔记本中,模型是用 `Experiment` 类和 `learn_runner`(我找不到关于它的任何文档)训练的。现在我想用训练好的模型进行预测,于是尝试了下面的代码,但出现了错误。请问正确的做法是什么?谢谢。
我在笔记本末尾添加了以下代码:
# load the saved model
estimator = tflearn.LinearRegressor(feature_columns=feature_cols, model_dir='taxi_trained')
estimator.predict(input_fn=get_test)
得到的错误信息如下:
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_is_chief': True, '_model_dir': None, '_save_checkpoints_secs': 600, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x00000218611630F0>, '_master': '', '_task_id': 0, '_keep_checkpoint_every_n_hours': 10000, '_evaluation_master': '', '_environment': 'local', '_num_worker_replicas': 0, '_tf_random_seed': None, '_tf_config': gpu_options {
per_process_gpu_memory_fraction: 1
}
, '_save_checkpoints_steps': None, '_keep_checkpoint_max': 5, '_task_type': None, '_num_ps_replicas': 0, '_save_summary_steps': 100}
WARNING:tensorflow:From c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\python\util\deprecation.py:335: calling LinearRegressor.predict (from tensorflow.contrib.learn.python.learn.estimators.linear) with outputs=None is deprecated and will be removed after 2017-03-01.
Instructions for updating:
Please switch to predict_scores, or set `outputs` argument.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-5-7f1903437174> in <module>()
1 with tf.Session() as sess:
2 estimator = tflearn.LinearRegressor(feature_columns=feature_cols, model_dir='taxi_trained')
----> 3 estimator.predict(input_fn=get_test)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
333 _call_location(), decorator_utils.get_qualified_name(func),
334 func.__module__, arg_name, arg_value, date, instructions)
--> 335 return func(*args, **kwargs)
336 new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
337 func.__doc__, date, instructions)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
333 _call_location(), decorator_utils.get_qualified_name(func),
334 func.__module__, arg_name, arg_value, date, instructions)
--> 335 return func(*args, **kwargs)
336 new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
337 func.__doc__, date, instructions)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\linear.py in predict(self, x, input_fn, batch_size, outputs, as_iterable)
755 input_fn=input_fn,
756 batch_size=batch_size,
--> 757 as_iterable=as_iterable)
758 return super(LinearRegressor, self).predict(
759 x=x,
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
333 _call_location(), decorator_utils.get_qualified_name(func),
334 func.__module__, arg_name, arg_value, date, instructions)
--> 335 return func(*args, **kwargs)
336 new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
337 func.__doc__, date, instructions)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\linear.py in predict_scores(self, x, input_fn, batch_size, as_iterable)
790 batch_size=batch_size,
791 outputs=[key],
--> 792 as_iterable=as_iterable)
793 if as_iterable:
794 return _as_iterable(preds, output=key)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
279 _call_location(), decorator_utils.get_qualified_name(func),
280 func.__module__, arg_name, date, instructions)
--> 281 return func(*args, **kwargs)
282 new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
283 func.__doc__, date, instructions)
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in predict(self, x, input_fn, batch_size, outputs, as_iterable)
563 feed_fn=feed_fn,
564 outputs=outputs,
--> 565 as_iterable=as_iterable)
566
567 def get_variable_value(self, name):
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in _infer_model(self, input_fn, feed_fn, outputs, as_iterable, iterate_batches)
855 contrib_framework.create_global_step(g)
856 features = self._get_features_from_input_fn(input_fn)
--> 857 infer_ops = self._get_predict_ops(features)
858 predictions = self._filter_predictions(infer_ops.predictions, outputs)
859 mon_sess = monitored_session.MonitoredSession(
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in _get_predict_ops(self, features)
1186 labels = tensor_signature.create_placeholders_from_signatures(
1187 self._labels_info)
-> 1188 return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
1189
1190 def export_savedmodel(
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py in _call_model_fn(self, features, labels, mode)
1101 if 'model_dir' in model_fn_args:
1102 kwargs['model_dir'] = self.model_dir
-> 1103 model_fn_results = self._model_fn(features, labels, **kwargs)
1104
1105 if isinstance(model_fn_results, model_fn_lib.ModelFnOps):
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\linear.py in _linear_model_fn(features, labels, mode, params, config)
159 num_outputs=head.logits_dimension,
160 weight_collections=[parent_scope],
--> 161 scope=scope)
162
163 def _train_op_fn(loss):
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\layers\python\layers\feature_column_ops.py in weighted_sum_from_feature_columns(columns_to_tensors, feature_columns, num_outputs, weight_collections, trainable, scope)
529 # pylint: disable=protected-access
530 for column in sorted(set(feature_columns), key=lambda x: x.key):
--> 531 transformed_tensor = transformer.transform(column)
532 try:
533 embedding_lookup_arguments = column._wide_embedding_lookup_arguments(
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\layers\python\layers\feature_column_ops.py in transform(self, feature_column)
880 return self._columns_to_tensors[feature_column]
881
--> 882 feature_column.insert_transformed_feature(self._columns_to_tensors)
883
884 if feature_column not in self._columns_to_tensors:
c:\users\tommy\appdata\local\programs\python\python35\lib\site-packages\tensorflow\contrib\layers\python\layers\feature_column.py in insert_transformed_feature(self, columns_to_tensors)
1406 """
1407 # Transform the input tensor according to the normalizer function.
-> 1408 input_tensor = self._normalized_input_tensor(columns_to_tensors[self.name])
1409 columns_to_tensors[self] = math_ops.to_float(input_tensor)
1410
KeyError: 'dropofflat'
我的环境是 Windows 10,使用 TensorFlow 1.1 和 Python 3.5。