Trying to read a CSV dataset into TensorFlow

Time: 2018-08-09 01:11:49

Tags: python tensorflow

I have CSV files (train and test) that contain data in the following format:

[image: sample of the CSV data]

I am trying to load my data (100 MB) into a TensorFlow input pipeline (I'm not sure this is the right way to do it).

x, y = tf.placeholder(tf.float32, shape=[None,2]), tf.placeholder(tf.float32, shape=[None,1])
dataset = tf.data.Dataset.from_tensor_slices((x, y))

train_dataset = tf.contrib.data.make_csv_dataset(csv_path+'\csvdatatrain.csv', batch_size=32)
test_dataset = tf.contrib.data.make_csv_dataset(csv_path+'\csvdatatest.csv', batch_size=32)


iter = dataset.make_initializable_iterator()
features, labels = iter.get_next()
with tf.Session() as sess:
    #initialise iterator with train data
    sess.run(iter.initializer, feed_dict={x: train_dataset[0], y: train_dataset[1]})

    for _ in range(EPOCHS):
        #### Training
        print('This is training')

        #### Testing
        sess.run(iter.initializer, feed_dict={ x: test_dataset[0], y: test_dataset[1]})
        print('This is testing')

I get an error when initializing the iterator with the training data:

sess.run(iter.initializer, feed_dict={x: train_dataset[1], y: train_dataset[0]})
TypeError: 'PrefetchDataset' object does not support indexing

1 answer:

Answer 0 (score: 1)

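If you want to use the TensorFlow data loader, you don't need placeholders. Either load the CSV file with NumPy and pass it to the placeholders through feed_dict, or use the TensorFlow data loader and pass it only the path to the CSV file.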
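The first option (NumPy plus feed_dict) could look roughly like the sketch below. It assumes a purely numeric CSV with one header row, two feature columns followed by one label column, which is a guess based on the placeholder shapes in the question; the helper name `load_csv_arrays` is hypothetical.

```python
import numpy as np


def load_csv_arrays(path, num_features=2):
    """Load a numeric CSV with a header row into (features, labels) arrays.

    Assumed layout: `num_features` feature columns followed by one label column.
    """
    # skiprows=1 drops the header line; ndmin=2 keeps the result 2-D
    # even if the file holds a single data row.
    data = np.loadtxt(path, delimiter=",", skiprows=1, dtype=np.float32, ndmin=2)
    features = data[:, :num_features]                # shape [N, num_features]
    labels = data[:, num_features:num_features + 1]  # shape [N, 1]
    return features, labels
```

The two arrays match the shapes [None, 2] and [None, 1] of the placeholders in the question, so they could be passed as feed_dict={x: features, y: labels} when initializing an iterator built with from_tensor_slices.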

If you use the TensorFlow data loader, you don't need to initialize the iterator in every epoch. It is initialized only once and needs no feed_dict.

This code worked for me on TensorFlow 1.8 (it uses the iris dataset):

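import tensorflow as tf  # written against TensorFlow 1.8; tf.contrib was removed in 2.x

EPOCHS = 10
# Pass only the CSV path; the dataset yields dicts mapping column names to batched tensors.
dataset = tf.contrib.data.make_csv_dataset('iris.csv', batch_size=1)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    # Initialize the iterator once; no feed_dict is needed.
    sess.run(iterator.initializer)
    for _ in range(EPOCHS):
        # Each run fetches one batch (here a single row).
        print(sess.run(next_element))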

The output should look like this:

{'sepal_length': array([ 6.4000001], dtype=float32), 'sepal_width': array([ 2.79999995], dtype=float32), 'petal_length': array([ 5.5999999], dtype=float32), 'petal_width': array([ 2.20000005], dtype=float32), 'species': array([b'virginica'], dtype=object)}
{'sepal_length': array([ 5.19999981], dtype=float32), 'sepal_width': array([ 2.70000005], dtype=float32), 'petal_length': array([ 3.9000001], dtype=float32), 'petal_width': array([ 1.39999998], dtype=float32), 'species': array([b'versicolor'], dtype=object)}
{'sepal_length': array([ 4.80000019], dtype=float32), 'sepal_width': array([ 3.4000001], dtype=float32), 'petal_length': array([ 1.89999998], dtype=float32), 'petal_width': array([ 0.2], dtype=float32), 'species': array([b'setosa'], dtype=object)}
{'sepal_length': array([ 7.69999981], dtype=float32), 'sepal_width': array([ 2.79999995], dtype=float32), 'petal_length': array([ 6.69999981], dtype=float32), 'petal_width': array([ 2.], dtype=float32), 'species': array([b'virginica'], dtype=object)}
{'sepal_length': array([ 4.4000001], dtype=float32), 'sepal_width': array([ 3.], dtype=float32), 'petal_length': array([ 1.29999995], dtype=float32), 'petal_width': array([ 0.2], dtype=float32), 'species': array([b'setosa'], dtype=object)}
{'sepal_length': array([ 5.], dtype=float32), 'sepal_width': array([ 3.], dtype=float32), 'petal_length': array([ 1.60000002], dtype=float32), 'petal_width': array([ 0.2], dtype=float32), 'species': array([b'setosa'], dtype=object)}
{'sepal_length': array([ 6.4000001], dtype=float32), 'sepal_width': array([ 3.20000005], dtype=float32), 'petal_length': array([ 5.30000019], dtype=float32), 'petal_width': array([ 2.29999995], dtype=float32), 'species': array([b'virginica'], dtype=object)}
{'sepal_length': array([ 4.5], dtype=float32), 'sepal_width': array([ 2.29999995], dtype=float32), 'petal_length': array([ 1.29999995], dtype=float32), 'petal_width': array([ 0.30000001], dtype=float32), 'species': array([b'setosa'], dtype=object)}
{'sepal_length': array([ 5.], dtype=float32), 'sepal_width': array([ 3.5], dtype=float32), 'petal_length': array([ 1.60000002], dtype=float32), 'petal_width': array([ 0.60000002], dtype=float32), 'species': array([b'setosa'], dtype=object)}
{'sepal_length': array([ 6.], dtype=float32), 'sepal_width': array([ 2.20000005], dtype=float32), 'petal_length': array([ 5.], dtype=float32), 'petal_width': array([ 1.5], dtype=float32), 'species': array([b'virginica'], dtype=object)}
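Each batch in the output above is a dict keyed by column name rather than a (features, labels) pair. If a feature matrix and a label array are needed, one possible way to split such a batch with NumPy is sketched here. The helper name `split_batch` is hypothetical, and sorting the column names is only one way to fix the column order.

```python
import numpy as np


def split_batch(batch, label_key="species"):
    """Split a column-name -> array dict into feature names, a feature matrix, and labels."""
    # Sort the feature names so the column order is deterministic.
    feature_keys = sorted(k for k in batch if k != label_key)
    # Each column has shape [batch_size]; stacking on axis 1
    # gives a matrix of shape [batch_size, n_features].
    features = np.stack([batch[k] for k in feature_keys], axis=1)
    labels = batch[label_key]
    return feature_keys, features, labels


# One batch shaped like the first output line above (values rounded).
batch = {
    "sepal_length": np.array([6.4], dtype=np.float32),
    "sepal_width": np.array([2.8], dtype=np.float32),
    "petal_length": np.array([5.6], dtype=np.float32),
    "petal_width": np.array([2.2], dtype=np.float32),
    "species": np.array([b"virginica"], dtype=object),
}
feature_keys, features, labels = split_batch(batch)
print(feature_keys)
print(features.shape)  # (1, 4)
```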