如何将数据读入Tensorflow?

时间:2016-02-26 05:36:19

标签: python mongodb csv tensorflow

我试图将CSV文件中的数据读取到tensorflow,

https://www.tensorflow.org/versions/r0.7/how_tos/reading_data/index.html#filenames-shuffling-and-epoch-limits

官方文档中的示例代码如下:

col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)

要读取文件,我需要预先知道文件中有多少列和行,如果有1000列,我需要定义1000个变量,如col1, col2, col3, col4, col5,..., col1000 ,这看起来不像有效的数据读取方式。

我的问题

  1. 将CSV文件读入Tensorflow的最佳方法是什么?

  2. 有没有办法在Tensorflow中读取数据库(例如mongoDB)?

4 个答案:

答案 0 :(得分:5)

  1. 你绝对不需要将col1,col2定义为col1000 ......

    通常,您可能会这样做:

    
    columns = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.pack(columns)
    do_whatever_you_want_to_play_with_features(features)
    
  2. 我不知道从MongoDB直接读取数据的任何现成方法。也许您可以编写一个简短的脚本来以Tensorflow支持的格式转换MongoDB中的数据,我建议使用二进制格式TFRecord,这比csv记录快得多。 This是关于此主题的好文章。或者您可以选择自己实施自定义数据阅读器,请参阅此处的the official doc

答案 1 :(得分:2)

def func()
    return 1,2,3,4

b = func() 

print b #(1, 2, 3, 4)

print [num for num in b] # [1, 2, 3, 4]

嗨它与tensorflow无关,它的简单python不需要定义1000变量。 tf.decode_csv返回一个元组。

不知道数据库处理,我想你可以使用python,只需将数组形式的数据输入到tensorflow。

希望这是有帮助的

答案 2 :(得分:1)

当然你可以实现从mongo直接读取批量随机排序训练数据以提供给tensorflow。以下是我的方式:

        for step in range(self.steps):


            pageNum=1;
            while(True):
                trainArray,trainLabelsArray = loadBatchTrainDataFromMongo(****)
                if len(trainArray)==0:
                    logging.info("train datas consume up!")
                    break;
                logging.info("started to train")
                sess.run([model.train_op],
                         feed_dict={self.input: trainArray,
                                    self.output: np.asarray(trainLabelsArray),
                                    self.keep_prob: params['dropout_rate']})

                pageNum=pageNum+1;

并且您还需要在mongodb中预处理经过训练的数据,例如:在mongodb中为每个训练数据分配一个随机排序值......

答案 3 :(得分:0)

  

有什么方法可以在Tensorflow中读取数据库(例如mongoDB)吗?

尝试TFMongoDB,这是TensorFlow的C ++实现的数据集操作,可让您连接到MongoDB:

pip install tfmongodb

在GitHub页面上有一个有关如何读取数据的示例。另请参见pypi: TFMongoDB