创建具有匹配元素的元组数组 - mnist data

时间:2017-11-20 16:43:49

标签: python numpy

我有一个csv数据,数据的第一列是'label',第一列到第784列之后的列包含图像(28 * 28)格式的表示。

我正在尝试创建这两个数组。我创建它但我喜欢的格式没有出现。

这是我使用的代码:

import csv
import numpy as np
import pandas as pd

with open(dir_path+'train0.csv', 'rU') as csv_file:
    for df in csv.reader(csv_file):
        label=np.array(df[0], dtype=float)
        pixels=np.array(df[1:], dtype='float').reshape((28,28))
        print zip((label, pixels))

结果:

[(array(0.0),), (array([[   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
           0.]]),)]
  

但是我想要的格式是:

(array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]], dtype=float32), array([5, 0, 4, ..., 8, 4, 8]))

请注意,我将标签和像素值从我的问题切换到所需的输出。这基本上是两个条目的元组数组。

这是因为我使用的是csv。我不能修复它。任何帮助将不胜感激

  

这是我最终得到的解决方案:

filename=dir_path+'train1.csv'

def load(filename):
    # read file into a list of rows
    with open(filename, 'rU') as csvfile:
        lines = csv.reader(csvfile, delimiter=',')
        rows = list(lines)

    # create empty numpy arrays of the required size
    data = np.empty((len(rows), len(rows[0])-1), dtype=np.float64)
    expected = np.empty((len(rows),), dtype=np.int64)

    # fill array with data from the csv-rows
    for i, row in enumerate(rows):
        data[i,:] = row[1:]
        expected[i] = row[0]

    training_data = data, expected
    return training_data

print load(filename)
  

结果:

(array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       ..., 
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]]), array([1, 1, 1, ..., 1, 1, 1]))

参考:stackoverflow.com/search?q=formatting+my+mnist+tuple

2 个答案:

答案 0 :(得分:0)

使用示例'csv'文字:

In [41]: txt = b'''label1 1 2 3 4
    ...: label2 8 9 10 11
    ...: label3 10 11 12 13
    ...: '''

和化合物dtype:

In [46]: dt = np.dtype([('label','U10'),('image',float,(2,2))])

genfromtxt可以将列加载为标签和3d图像字段:

In [47]: data = np.genfromtxt(txt.splitlines(), dtype=dt)
In [48]: data
Out[48]: 
array([('label1', [[  1.,   2.], [  3.,   4.]]),
       ('label2', [[  8.,   9.], [ 10.,  11.]]),
       ('label3', [[ 10.,  11.], [ 12.,  13.]])],
      dtype=[('label', '<U10'), ('image', '<f8', (2, 2))])
In [49]: data['image']
Out[49]: 
array([[[  1.,   2.],
        [  3.,   4.]],

       [[  8.,   9.],
        [ 10.,  11.]],

       [[ 10.,  11.],
        [ 12.,  13.]]])

您可以改变dtype以满足您的需求。

dt = np.dtype([('label','U10'),('image',float,(4,))])

答案 1 :(得分:0)

这是解决方案

filename=dir_path+'train1.csv'

    def load(filename):
        # read file into a list of rows
        with open(filename, 'rU') as csvfile:
            lines = csv.reader(csvfile, delimiter=',')
            rows = list(lines)

        # create empty numpy arrays of the required size
        data = np.empty((len(rows), len(rows[0])-1), dtype=np.float64)
        expected = np.empty((len(rows),), dtype=np.int64)

        # fill array with data from the csv-rows
        for i, row in enumerate(rows):
            data[i,:] = row[1:]
            expected[i] = row[0]

        training_data = data, expected
        return training_data

    print load(filename)