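I can find sample code for reading data from a CSV. My requirement is that I need to generate training and test data from a CSV in TensorFlow.
A single CSV contains both the train and test data. I mean the first 10 rows are for training and the next 10 rows are for testing. Thanks in advance.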
Answer 0 (score: 2)
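The TensorFlow team has written an excellent tutorial on doing this. It covers how to read census data from a CSV, convert it to tensors, and use the high-level Estimator API to fit and evaluate a machine learning model.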
However, I did get an error when I tried to use the urllib function, so I modified the code slightly to read the data directly with pandas.
Original code
import tempfile
import urllib
train_file = tempfile.NamedTemporaryFile()
test_file = tempfile.NamedTemporaryFile()
urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)
import pandas as pd
CSV_COLUMNS = [
"age", "workclass", "fnlwgt", "education", "education_num",
"marital_status", "occupation", "relationship", "race", "gender",
"capital_gain", "capital_loss", "hours_per_week", "native_country",
"income_bracket"]
df_train = pd.read_csv(train_file.name, names=CSV_COLUMNS, skipinitialspace=True)
df_test = pd.read_csv(test_file.name, names=CSV_COLUMNS, skipinitialspace=True, skiprows=1)
Modified code
import pandas as pd
COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
"marital_status", "occupation", "relationship", "race", "gender",
"capital_gain", "capital_loss", "hours_per_week", "native_country",
"income_bracket"]
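# Read the training data straight from the URL; the file has no header row.
df_train = pd.read_csv(
    'http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data',
    names=COLUMNS,
    skipinitialspace=True)
# skiprows=1 skips the first line of adult.test, which is not a data row.
df_test = pd.read_csv(
    'http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test',
    names=COLUMNS,
    skipinitialspace=True,
    skiprows=1)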
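
The code above reads train and test data from two separate files. For the case in the question, where one CSV holds both sets (first 10 rows for training, next 10 for testing), a minimal sketch with pandas might look like the following. The in-memory CSV and the column names feature_a, feature_b and label are made up for illustration; substitute your own file path and columns.

```python
import io

import pandas as pd

# Hypothetical stand-in for the real file: 20 data rows, no header row.
csv_text = "\n".join(f"{i},{i * 2},{i % 2}" for i in range(20))

# Hypothetical column names; replace with the columns of your CSV.
COLUMNS = ["feature_a", "feature_b", "label"]

# Read the whole CSV once (pass a file path instead of StringIO for a real file).
df = pd.read_csv(io.StringIO(csv_text), names=COLUMNS, skipinitialspace=True)

# Rows 0-9 go to training, rows 10-19 to testing.
df_train = df.iloc[:10]
df_test = df.iloc[10:20].reset_index(drop=True)

print(len(df_train), len(df_test))
```

The resulting df_train and df_test can then be used in place of the two DataFrames read from separate URLs above.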