TensorFlow - 使用字符串字段

时间:2017-12-18 22:24:31

标签: python tensorflow

我是Tensorflow的新手并尝试在包含60列的.csv文件上运行神经网络。但是其中一些包含字符串字段。我试着运行我得到的程序could not convert string to float:这是代码。

# Load datasets.
  training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
      filename=TRAINING,
      target_dtype=np.int,
      features_dtype=np.float32)

  test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
      filename=TEST,
      target_dtype=np.int,
      features_dtype=np.float32)

  # Specify that all features have real-value data
  feature_columns = [tf.feature_column.numeric_column("x", shape=[59])]


  classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                          hidden_units=[59],
                                          n_classes=2)

现在我读到target_dtype和features_dtype采用numpy类型。我在这里搜索了https://docs.scipy.org/doc/numpy/user/basics.types.html,看起来他们没有字符串字段。 实现这一目标的最佳方法是什么?

1 个答案:

答案 0 :(得分:2)

有两种方法。

首先,您可以在csv中修改数据,删除无法转换为'float'的字符串。要使用tf.estimator Quickstart中的演示代码,您应该保留csv格式,例如iris_training.csviris_test.csv

第二种方法,您可以修改您调用的函数load_csv_without_header的代码。原始代码如下:

def load_csv_without_header(filename,
                        target_dtype,
                        features_dtype,
                        target_column=-1):
  """Load dataset from CSV file without a header row."""
  with gfile.Open(filename) as csv_file:
    data_file = csv.reader(csv_file)
    data, target = [], []
    for row in data_file:
      target.append(row.pop(target_column))
      data.append(np.asarray(row, dtype=features_dtype))

    target = np.array(target, dtype=target_dtype)
    data = np.array(data)
    return Dataset(data=data, target=target)

这里,它使用了一些常见的模块,如csv,numpy,集合,python的功能,如next,enumerate,tensorflow中的函数,如gfile。您可以调试此代码,然后修改数据的代码。

此外,您可以使用tf.decode_csv

最后,欢迎来到张量流。