Custom iterator example not working

Date: 2017-08-16 17:21:59

Tags: python iterator mxnet

I followed the instructions and the example for creating a custom iterator described here: http://mxnet.io/tutorials/basic/data.html

The following code produces a ValueError when it reaches the call mod.fit(data_iter, num_epoch=5):

ValueError: Shape of labels 0 does not match shape of predictions 1
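The numbers in this message appear to be array counts rather than tensor shapes. As far as I can tell, MXNet's metric helper `check_label_shapes` (in `mxnet.metric`) compares `len(labels)` with `len(preds)` by default and reports those lengths, so "labels 0 ... predictions 1" would mean the metric received an empty list of label arrays and one prediction array (the network's single `softmax` output). The function below is a hypothetical re-creation of that check for illustration only, not MXNet's actual source:

```python
def check_label_counts(labels, preds):
    """Hypothetical re-creation of a metric's length check: it compares how many
    label arrays and prediction arrays were passed, not their tensor shapes."""
    if len(labels) != len(preds):
        raise ValueError(
            "Shape of labels {} does not match shape of predictions {}".format(
                len(labels), len(preds)))


# One network output (softmax) but an empty label list from the iterator.
try:
    check_label_counts(labels=[], preds=["softmax_output"])
except ValueError as err:
    print(err)  # Shape of labels 0 does not match shape of predictions 1
```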

My questions:

  • Can anyone reproduce this problem?
  • Does anyone know a solution?

I'm using Jupyter on a Mac with everything freshly installed, including Python. I also tested it directly in plain Python, using:

Python 3.6.1 |Anaconda custom (x86_64)| (default, May 11 2017, 13:04:09) 
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)] on darwin

Code

import mxnet as mx
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


import numpy as np
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])

#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 20]


# let's save `data` into a csv file first and try reading it back
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 20]

class SimpleIter(mx.io.DataIter):
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        self._provide_data = zip(data_names, data_shapes)
        self._provide_label = zip(label_names, label_shapes)
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data,\
                                                        self.data_gen)]
            label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label,\
                                                        self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration


import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())

#['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
#['softmax_output']



import logging
logging.basicConfig(level=logging.INFO)

n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['softmax_label'], [(n,)],
                  [lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)

1 Answer:

Answer 0 (score: 0)

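This is a Python version hell problem. I was able to get everything working and compiling with Python 2.7. Python 3.x seems to be what causes the problem, and the error message doesn't really help...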
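One plausible mechanism behind the Python 2/3 difference, offered as a hedged sketch rather than a confirmed diagnosis: `SimpleIter.__init__` stores `self._provide_data` and `self._provide_label` as `zip(...)` objects. In Python 2 `zip` returns a list, but in Python 3 it returns a one-shot iterator. `Module.fit` reads `provide_data` and `provide_label` when it binds the module, which would exhaust both iterators; after that, the list comprehensions in `next()` iterate over nothing and build empty `data` and `label` lists. Zero label arrays against the single `softmax` output is consistent with the "labels 0 ... predictions 1" in the error message. The pure-Python snippet below mimics only that storage pattern (the class `ShapeHolder` and the function `read_twice` are hypothetical; no MXNet is needed):

```python
class ShapeHolder:
    """Hypothetical stand-in for SimpleIter.__init__: stores (name, shape) pairs
    either as a raw zip object or as a materialized list."""

    def __init__(self, names, shapes, materialize):
        pairs = zip(names, shapes)
        # Python 3: zip() is a lazy iterator that can be consumed only once.
        self.provide_label = list(pairs) if materialize else pairs


def read_twice(holder):
    """Read provide_label twice, the way a framework might: once to bind shapes,
    then again while building each batch."""
    first = [shape for _, shape in holder.provide_label]
    second = [shape for _, shape in holder.provide_label]
    return first, second


broken = ShapeHolder(["softmax_label"], [(32,)], materialize=False)
fixed = ShapeHolder(["softmax_label"], [(32,)], materialize=True)

print(read_twice(broken))  # ([(32,)], []): the second read finds nothing
print(read_twice(fixed))   # ([(32,)], [(32,)])
```

If this is indeed the cause, the minimal change to the question's code is to materialize both lists in `__init__`: `self._provide_data = list(zip(data_names, data_shapes))` and `self._provide_label = list(zip(label_names, label_shapes))`. That change is harmless on Python 2.7, where `zip` already returns a list, which would also explain why the same code ran there.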