输入csv与他们一起训练

时间:2018-05-05 09:58:32

标签: tensorflow

我之前测试过的CNN模型。但是,我是新手导入自己的数据,并且我一直在收到错误。任何人都可以告诉我我做错了什么,以及导入数据的正确方法是什么,以便模型可以运行?我可以使用任何资源,例如在线书籍或指南吗?

这是我目前的代码;

code

https://pastebin.com/wKtidYGL

测试csv是(624,12000) 火车csv是(624,362) 测试和训练标签也是(11999,1)和(361,1)

1 个答案:

答案 0 :(得分:0)

Tensorflow开发人员已经制作了许多有用的CSV处理功能,以便能够轻松地将其提供给您的培训模型。

我建议您阅读此官方documentation page的CSV部分。

import iris_data
train_path, test_path = iris_data.maybe_download()

ds = tf.data.TextLineDataset(train_path).skip(1)

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
       'PetalLength', 'PetalWidth',
       'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def _parse_line(line):
   # Decode the line into its fields
   fields = tf.decode_csv(line, FIELD_DEFAULTS)

   # Pack the result into a dictionary
   features = dict(zip(COLUMNS,fields))

   # Separate the label from the features
   label = features.pop('label')

   return features, label

   ds = ds.map(_parse_line)
   print(ds)