TensorFlow DNNLinearCombinedClassifier 分批训练（train batch）

时间:2017-07-02 04:56:53

标签: python dll machine-learning tensorflow classification

我的代码:

def batch_input_fn(df, batch_size):
    """Return an Estimator ``input_fn`` that yields mini-batches of ``df``.

    The returned function must be passed to the Estimator directly, e.g.
    ``m.fit(input_fn=batch_input_fn(df, 256), steps=...)``. Wrapping the call
    in a lambda (``input_fn=lambda: batch_input_fn(df, 256)``) hands the
    Estimator a function where it expects a ``(features, labels)`` pair, and
    unpacking it raises ``TypeError: 'function' object is not iterable``.

    Args:
        df: Table indexed by column name (``df[k].values``); it must contain
            every name in ``CONTINUOUS_COLUMNS`` and ``CATEGORICAL_COLUMNS``
            plus a ``"label"`` column.
        batch_size: Number of rows per batch passed to ``tf.train.batch``.

    Returns:
        A zero-argument function that builds the queue-based input pipeline
        and returns ``(features, label)``. ``features`` maps each continuous
        column name to a dense batch tensor and each categorical column name
        to a ``tf.SparseTensor`` with one value per row in column 0 (dense
        shape ``[batch, 1]``). ``label`` is the batched integer label tensor.
    """
    def _input_fn():
        """Build the input pipeline and return ``(features, label)``."""
        # tf.train.slice_input_producer accepts only a list of dense tensors
        # (not a dict and not SparseTensors). So every column, including the
        # categorical ones as plain string tensors, is sliced densely. The
        # label is appended as the last entry.
        column_names = list(CONTINUOUS_COLUMNS) + list(CATEGORICAL_COLUMNS)
        column_tensors = [tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS]
        column_tensors += [
            tf.constant(df[k].values.astype(str)) for k in CATEGORICAL_COLUMNS
        ]
        column_tensors.append(tf.constant(df["label"].values.astype(int)))

        # slice_input_producer emits one row per dequeue. tf.train.batch
        # regroups rows into batches and returns a *list* in the same order
        # as its input, not a (features, labels) tuple.
        sliced_row = tf.train.slice_input_producer(column_tensors)
        batched = tf.train.batch(sliced_row, batch_size=batch_size)

        label = batched[-1]
        features = dict(zip(column_names, batched[:-1]))

        # Re-wrap each batched categorical column as a SparseTensor holding
        # one value per row at column 0, i.e. indices [[0, 0], [1, 0], ...]
        # and dense_shape [batch, 1].
        for k in CATEGORICAL_COLUMNS:
            values = features[k]
            num_rows = tf.shape(values)[0]
            rows = tf.range(num_rows)
            indices = tf.cast(
                tf.stack([rows, tf.zeros_like(rows)], axis=1), tf.int64)
            dense_shape = tf.cast(tf.stack([num_rows, 1]), tf.int64)
            features[k] = tf.SparseTensor(
                indices=indices, values=values, dense_shape=dense_shape)

        return features, label

    return _input_fn

这是错误信息:

  File "D:/code/mobike_end_loc/slover/wide_n_deep_V2.py", line 117, in make_sub
    m.fit(input_fn=lambda: batch_input_fn(df_train.head(len(df_train)-1000),256), steps=train_steps)
  File "D:\code\Zhihu_KanShanBei\venv\lib\site-packages\tensorflow\python\util\deprecation.py", line 281, in new_func
    return func(*args, **kwargs)
  File "D:\code\Zhihu_KanShanBei\venv\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py", line 430, in fit
    loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "D:\code\Zhihu_KanShanBei\venv\lib\site-packages\tensorflow\contrib\learn\python\learn\estimators\estimator.py", line 925, in _train_model
    features, labels = input_fn()
TypeError: 'function' object is not iterable

1 个答案:

答案 0 :(得分:0)

here is my code 

def batch_input_fn(df, batch_size):
    """Return an Estimator ``input_fn`` that yields mini-batches of ``df``.

    The returned function must be passed to the Estimator directly, e.g.
    ``m.fit(input_fn=batch_input_fn(df, 256), steps=...)``. Wrapping the call
    in a lambda (``input_fn=lambda: batch_input_fn(df, 256)``) hands the
    Estimator a function where it expects a ``(features, labels)`` pair, and
    unpacking it raises ``TypeError: 'function' object is not iterable``.

    Args:
        df: Table indexed by column name (``df[k].values``); it must contain
            every name in ``CONTINUOUS_COLUMNS`` and ``CATEGORICAL_COLUMNS``
            plus a ``"label"`` column.
        batch_size: Number of rows per batch passed to ``tf.train.batch``.

    Returns:
        A zero-argument function that builds the queue-based input pipeline
        and returns ``(features, label)``. ``features`` maps each continuous
        column name to a dense batch tensor and each categorical column name
        to a ``tf.SparseTensor`` with one value per row in column 0 (dense
        shape ``[batch, 1]``). ``label`` is the batched integer label tensor.
    """
    def _input_fn():
        """Build the input pipeline and return ``(features, label)``."""
        # tf.train.slice_input_producer accepts only a list of dense tensors
        # (not a dict and not SparseTensors). So every column, including the
        # categorical ones as plain string tensors, is sliced densely. The
        # label is appended as the last entry.
        column_names = list(CONTINUOUS_COLUMNS) + list(CATEGORICAL_COLUMNS)
        column_tensors = [tf.constant(df[k].values) for k in CONTINUOUS_COLUMNS]
        column_tensors += [
            tf.constant(df[k].values.astype(str)) for k in CATEGORICAL_COLUMNS
        ]
        column_tensors.append(tf.constant(df["label"].values.astype(int)))

        # slice_input_producer emits one row per dequeue. tf.train.batch
        # regroups rows into batches and returns a *list* in the same order
        # as its input, not a (features, labels) tuple.
        sliced_row = tf.train.slice_input_producer(column_tensors)
        batched = tf.train.batch(sliced_row, batch_size=batch_size)

        label = batched[-1]
        features = dict(zip(column_names, batched[:-1]))

        # Re-wrap each batched categorical column as a SparseTensor holding
        # one value per row at column 0, i.e. indices [[0, 0], [1, 0], ...]
        # and dense_shape [batch, 1].
        for k in CATEGORICAL_COLUMNS:
            values = features[k]
            num_rows = tf.shape(values)[0]
            rows = tf.range(num_rows)
            indices = tf.cast(
                tf.stack([rows, tf.zeros_like(rows)], axis=1), tf.int64)
            dense_shape = tf.cast(tf.stack([num_rows, 1]), tf.int64)
            features[k] = tf.SparseTensor(
                indices=indices, values=values, dense_shape=dense_shape)

        return features, label

    return _input_fn