如何在TensorFlow

时间:2017-08-18 20:41:17

标签: python tensorflow

我可以从csv获取示例代码来读取数据。 我的要求是我需要在TensorFlow中从CSV生成训练和测试数据。

一个包含Train和Test数据的CSV。我的意思是前10行我用于火车,接下来10行用于测试 在此先感谢

1 个答案:

答案 0 :(得分:2)

TensorFlow的人们已经创建了excellent tutorial来做到这一点。它介绍了如何从csv中读取人口普查数据,将其转换为张量,并使用高级估算器API来拟合和评估机器学习模型。

但是,当我尝试使用urllib函数时,我确实收到了错误,并且我稍微修改了代码,以便使用pandas直接读取数据。

原始代码

import tempfile
import urllib
train_file = tempfile.NamedTemporaryFile()
test_file = tempfile.NamedTemporaryFile()
urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)
urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)

import pandas as pd
CSV_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education_num",
    "marital_status", "occupation", "relationship", "race", "gender",
    "capital_gain", "capital_loss", "hours_per_week", "native_country",
    "income_bracket"]
df_train = pd.read_csv(train_file.name, names=CSV_COLUMNS, skipinitialspace=True)
df_test = pd.read_csv(test_file.name, names=CSV_COLUMNS, skipinitialspace=True, skiprows=1)

修改后的代码

import pandas as pd
COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
           "marital_status", "occupation", "relationship", "race", "gender",
           "capital_gain", "capital_loss", "hours_per_week", "native_country",
           "income_bracket"]

df_train = pd.read_csv('http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data'
                       , names=COLUMNS
                       , skipinitialspace=True)
df_test = pd.read_csv('http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test'
                      , names=COLUMNS
                      , skipinitialspace=True
                      , skiprows=1)