Batching with TensorFlow does not work as expected

Time: 2017-06-15 12:12:51

Tags: tensorflow batching

I am using the following code to do batching with the TensorFlow contrib library.

import tensorflow as tf  # TensorFlow 1.x; tf.contrib is not available in 2.x


def input_fn_batch(batch_size, train_data):
  """Input builder function."""
  default = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)
  base_data_values = tf.contrib.learn.read_batch_examples([train_data],
                                                          batch_size=batch_size,
                                                          reader=tf.TextLineReader,
                                                          num_epochs=1,
                                                          parse_fn=lambda x: tf.decode_csv(x, record_defaults=default))

  df_train = {}
  for i, column in enumerate(COLUMNS):
    df_train[column] = base_data_values[:, i]
  for column in CATEGORICAL_INT_COLUMNS:
    df_train[column] = tf.string_to_number(df_train[column], out_type=tf.int32)


  # Creates a dictionary mapping from each continuous feature column name (k) to
  # the values of that column parsed into a float32 Tensor.
  continuous_cols = {k: tf.string_to_number(df_train[k])
                   for k in CONTINUOUS_COLUMNS}

  # Creates a dictionary mapping from each categorical feature column name (k)
  # to the values of that column stored in a tf.SparseTensor.
  categorical_cols = {k: dense_to_sparse(df_train[k])
                    for k in CATEGORICAL_COLUMNS}

  # Merges the two dictionaries into one.
  feature_cols = dict(continuous_cols)
  feature_cols.update(categorical_cols)
  # Parses the label column from strings into an int32 Tensor.
  label = tf.string_to_number(df_train[LABEL_COLUMN], out_type=tf.int32)

  # Returns the feature columns and the label.
  return feature_cols, label


def dense_to_sparse(dense_tensor):
  # Treats the 1-D batch as an [n, 1] sparse matrix: entry i sits at index [i, 0].
  indices = tf.to_int64(tf.transpose(
      [tf.range(tf.shape(dense_tensor)[0]),
       tf.zeros_like(dense_tensor, dtype=tf.int32)]))
  values = dense_tensor
  shape = tf.to_int64([tf.shape(dense_tensor)[0], tf.constant(1)])

  return tf.SparseTensor(
      indices=indices,
      values=values,
      dense_shape=shape)

I call the fit function as follows:

estimator.fit(input_fn=lambda: input_fn_batch(1000, train_data), steps=200)

For some reason, it executes only one step. It seems to ignore the steps parameter.
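
One possible explanation, which the post itself does not confirm: with `num_epochs=1` the reader makes a single pass over the file, so training can end as soon as the input is exhausted, even though `steps=200` asks for more. The sketch below is plain Python arithmetic rather than TensorFlow code. The helper names `max_training_steps` and `steps_actually_run` are hypothetical, the row count is made up because the post does not say how many lines `train_data` contains, and whether a smaller final batch is emitted is treated as an assumption.

```python
import math


def max_training_steps(num_rows, batch_size, num_epochs,
                       allow_smaller_final_batch=True):
    """Upper bound on the batches a reader can supply before input runs out.

    Hypothetical helper for illustration only; it is not part of TensorFlow.
    num_epochs=None models a reader that cycles over the data indefinitely.
    """
    if num_epochs is None:
        return math.inf
    total_examples = num_rows * num_epochs
    if allow_smaller_final_batch:
        # A partial last batch still counts as one step.
        return math.ceil(total_examples / batch_size)
    # Otherwise leftover examples that do not fill a batch are dropped.
    return total_examples // batch_size


def steps_actually_run(requested_steps, num_rows, batch_size, num_epochs):
    # Training stops at whichever limit is reached first.
    return min(requested_steps,
               max_training_steps(num_rows, batch_size, num_epochs))


# Assumed file size of 800 rows (the post does not state the real number).
print(steps_actually_run(200, num_rows=800, batch_size=1000, num_epochs=1))     # 1
print(steps_actually_run(200, num_rows=800, batch_size=1000, num_epochs=None))  # 200
```

If this is indeed the cause, passing `num_epochs=None` to `read_batch_examples` (documented as cycling through the dataset forever) or using a smaller `batch_size` would be the first things to try.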

0 answers:

No answers yet