Dataset with a custom shape from a CSV file - tensorflow

Date: 2018-06-19 10:01:24

Tags: python tensorflow

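I have a CSV file containing my data, where each row is a vector of the form [1, 2, 3, 4]: 1, 2 and 3 are the features and 4 is the label. I want to read the CSV and extract from it batches of tensors of shape (32, 20, 3). So far I have written this code: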

import tensorflow as tf

FIELD_DEFAULTS = [[0], [0], [0], [0]]


def _parse_line(line):
    # Decode the line into its fields
    parsed_line = tf.decode_csv(line, FIELD_DEFAULTS)

    # Stack the three feature columns into a single tensor of shape (3,)
    features = tf.reshape(parsed_line[:-1], [3])

    # The last column is the label; give it shape (1,)
    label = tf.reshape(parsed_line[-1], [1])

    return features, label


ds = tf.data.TextLineDataset('data.csv').skip(1)
ds = ds.map(_parse_line)
ds = ds.batch(20)
iteratore = ds.make_initializable_iterator()
f1, l1 = iteratore.get_next()
num_epochs = 1

with tf.Session() as sess:
    for _ in range(num_epochs):
        sess.run(iteratore.initializer)
        try:
            while True:
                pino, antonio = sess.run([f1, l1])
                print(pino.shape) 

        except tf.errors.OutOfRangeError:
            pass

The output is:

(20, 3)
(20, 3)
(20, 3)
(20, 3)
(20, 3)

.......
.......
.......

(20, 3)
(20, 3)
(20, 3)
(20, 3)
(20, 3)
(20, 3)
(20, 3)
(1, 3)

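As I understand it, I get a list of tensors of shape (20, 3), except for the last one, which has shape (1, 3). How can I extract batches of shape (32, 20, 3) from these tensors? I am not interested in the label tensor for now.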

Thanks
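One possible way to get blocks of shape (32, 20, 3) is to apply `batch` twice and drop the incomplete remainders; this is a minimal sketch under assumptions, not a confirmed answer. It assumes a TensorFlow version in which `Dataset.batch` accepts `drop_remainder`, and it is written in the eager style of TF 2.x rather than the session loop used in the question. The `rows` array is a hypothetical stand-in for the parsed CSV features (1300 rows chosen only for illustration), and labels are left out.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the parsed CSV features: 1300 rows, 3 features each.
rows = np.arange(1300 * 3, dtype=np.int32).reshape(1300, 3)

ds = tf.data.Dataset.from_tensor_slices(rows)  # each element has shape (3,)
# First batch: group 20 consecutive rows, dropping a short final group.
ds = ds.batch(20, drop_remainder=True)         # each element has shape (20, 3)
# Second batch: group 32 of those, again dropping a short final group.
ds = ds.batch(32, drop_remainder=True)         # each element has shape (32, 20, 3)

# Collect the blocks as NumPy arrays (eager execution assumed).
blocks = [block.numpy() for block in ds]
shapes = [b.shape for b in blocks]
print(shapes)
```

With the graph-mode code from the question, the same two `batch` calls would presumably go right after `ds.map(_parse_line)`, with the iterator and session loop left unchanged. Without `drop_remainder=True`, the last element of each level can be shorter, which is where the trailing (1, 3) batch in the output above comes from.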

0 answers:

No answers