TensorFlow:使用 Estimator(model_fn)时如何为占位符(placeholder)提供值?

时间:2016-12-03 00:48:46

标签: python machine-learning tensorflow

我正在尝试构建一个 LSTM 模型,但遇到了以下错误:

 
        InvalidArgumentError Traceback (most recent call last)
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
        964     try:
    --> 965       return fn(*args)
        966     except errors.OpError as e:
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
        946                                  feed_dict, fetch_list, target_list,
    --> 947                                  status, run_metadata)
        948 
    /home/george/anaconda3/lib/python3.5/contextlib.py in exit(self, type, value, traceback)
         65             try:
    ---> 66                 next(self.gen)
         67             except StopIteration:
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/errors.py in raise_exception_on_not_ok_status()
        449           compat.as_text(pywrap_tensorflow.TF_Message(status)),
    --> 450           pywrap_tensorflow.TF_GetCode(status))
        451   finally:
    InvalidArgumentError: You must feed a value for placeholder tensor 'input' with dtype float
         [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
    During handling of the above exception, another exception occurred:
    InvalidArgumentError                      Traceback (most recent call last)
     in ()
          1 classificator.fit(X_train_TF, Y_train, monitors = [validation_monitor],
    ----> 2                   batch_size = batch_size, steps = training_steps)
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in fit(self, x, y, input_fn, steps, batch_size, monitors, max_steps)
        217                              steps=steps,
        218                              monitors=monitors,
    --> 219                              max_steps=max_steps)
        220     logging.info('Loss for final step: %s.', loss)
        221     return self
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in _train_model(self, input_fn, steps, feed_fn, init_op, init_feed_fn, init_fn, device_fn, monitors, log_every_steps, fail_on_nan_loss, max_steps)
        477       features, targets = input_fn()
        478       self._check_inputs(features, targets)
    --> 479       train_op, loss_op = self._get_train_ops(features, targets)
        480 
        481       # Add default monitors.
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in _get_train_ops(self, features, targets)
        747       Tuple of train Operation and loss Tensor.
        748     """
    --> 749     _, loss, train_op = self._call_model_fn(features, targets, ModeKeys.TRAIN)
        750     return train_op, loss
        751 
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py in _call_model_fn(self, features, targets, mode)
        731       else:
        732         return self._model_fn(features, targets, mode=mode)
    --> 733     return self._model_fn(features, targets)
        734 
        735   def _get_train_ops(self, features, targets):
    /home/george/ipython/project/lstm_model.py in model(X, y)
         61         output = lstm_layers(output[-1],dense_layers)
         62         prediction, loss = tflearn.run_n({"outputs": output, "last_states": layers}, n=1,
    ---> 63                                         feed_dict=None)
         64         train_operation = tflayers.optimize_loss(loss, tf.contrib.framework.get_global_step(), optimizer=optimizer,
         65                                                  learning_rate=learning_rate)
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/graph_actions.py in run_n(output_dict, feed_dict, restore_checkpoint_path, n)
        795       output_dict=output_dict,
        796       feed_dicts=itertools.repeat(feed_dict, n),
    --> 797       restore_checkpoint_path=restore_checkpoint_path)
        798 
        799 
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/graph_actions.py in run_feeds(*args, **kwargs)
        850 def run_feeds(*args, **kwargs):
        851   """See run_feeds_iter(). Returns a list instead of an iterator."""
    --> 852   return list(run_feeds_iter(*args, **kwargs))
        853 
        854 
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/graph_actions.py in run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path)
        841         threads = queue_runner.start_queue_runners(session, coord=coord)
        842         for f in feed_dicts:
    --> 843           yield session.run(output_dict, f)
        844       finally:
        845         coord.request_stop()
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
        708     try:
        709       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 710                          run_metadata_ptr)
        711       if run_metadata:
        712         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
        906     if final_fetches or final_targets:
        907       results = self._do_run(handle, final_targets, final_fetches,
    --> 908                              feed_dict_string, options, run_metadata)
        909     else:
        910       results = []
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
        956     if handle is None:
        957       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
    --> 958                            target_list, options, run_metadata)
        959     else:
        960       return self._do_call(_prun_fn, self._session, handle, feed_dict,
    /home/george/anaconda3/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
        976         except KeyError:
        977           pass
    --> 978       raise type(e)(node_def, op, message)
        979 
        980   def _extend_graph(self):

InvalidArgumentError: You must feed a value for placeholder tensor 'input' with dtype float
     [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Operation

    这是我的 notebook 代码,错误发生在我执行第三行(调用 classificator.fit 训练分类器)时:
Tensor
下面是我的模型定义(排版上有点乱,但所有函数都定义在 lstm_model 函数内部):
InvalidArgumentError: You must feed a value for placeholder tensor 'input' with dtype float
     [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

我做错了什么?有什么办法可以解决这个问题吗?提前感谢大家的建议 :)

1 个答案:

答案 0 :(得分:-1)

`.fit` 方法已改为接受 `input_fn` 作为参数,而不是直接接收训练数据及其标签。请参考这个示例(原文链接已丢失)。