我之前测试过的CNN模型。但是,我是新手导入自己的数据,并且我一直在收到错误。任何人都可以告诉我我做错了什么,以及导入数据的正确方法是什么,以便模型可以运行?我可以使用任何资源,例如在线书籍或指南吗?
这是我目前的代码;
code
https://pastebin.com/wKtidYGL
测试csv是(624,12000) 火车csv是(624,362) 测试和训练标签也是(11999,1)和(361,1)
答案 0 :(得分:0)
Tensorflow开发人员已经制作了许多有用的CSV处理功能,以便能够轻松地将其提供给您的培训模型。
我建议您阅读此官方documentation page的CSV部分。
import iris_data
train_path, test_path = iris_data.maybe_download()
ds = tf.data.TextLineDataset(train_path).skip(1)
# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
'PetalLength', 'PetalWidth',
'label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
# Decode the line into its fields
fields = tf.decode_csv(line, FIELD_DEFAULTS)
# Pack the result into a dictionary
features = dict(zip(COLUMNS,fields))
# Separate the label from the features
label = features.pop('label')
return features, label
ds = ds.map(_parse_line)
print(ds)