Generator instead of a collection for model.fit() input? (Training data too large to store in memory)

Posted: 2018-06-28 14:20:59

Tags: numpy tensorflow memory tflearn

The fit function of a tflearn model can be passed the training and test data like this:

model = tflearn.DNN(nn)

model.fit({'input': X_train},
          {'targets': Y_train},
          n_epoch=10,
          validation_set=(
              {'input': X_test},
              {'targets': Y_test}
          ))

where nn is the model definition. But what if a collection such as X_train is too large to fit in memory?

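In my case, I have compressed sparse binary vectors (every entry is 1 or 0) into lists of the indices of their non-zero entries, together with an integer encoding the vector's dimensionality, which lets me reconstruct the original vectors.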
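A minimal sketch of the compressed representation described above, assuming each sample is stored as a `(non_zero_indices, dimensionality)` pair. `compress` is a hypothetical helper name; `reconstruct_vector` here is a vectorized variant of the function used later in the question, using NumPy fancy indexing instead of a Python loop.

```python
import numpy as np

def compress(vec):
    """Return (non_zero_indices, dimensionality) for a dense 0/1 vector."""
    vec = np.asarray(vec)
    return np.flatnonzero(vec), vec.shape[0]

def reconstruct_vector(non_zero_indices, dimensionality):
    """Rebuild the dense 0/1 vector; fancy indexing sets all ones at once."""
    vec = np.zeros(dimensionality)
    vec[non_zero_indices] = 1
    return vec

dense = np.array([0, 1, 0, 0, 1, 1, 0])
indices, dim = compress(dense)
print(indices, dim)                      # [1 4 5] 7
restored = reconstruct_vector(indices, dim)
print(np.array_equal(restored, dense))   # True
```

Only the index arrays need to stay in memory; the dense vector exists just while it is being used.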

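The collection of compressed vectors does fit in memory, but the collection of full vectors does not. So instead I tried passing generators over X_train and the other collections (which now contain the compressed vectors), so that the full vectors are generated on the fly. However, model.fit requires a len() function, so I defined a custom Feeder class as follows: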

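class Feeder:
    """Wraps a collection and converts each item lazily during iteration."""

    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        # Generator method: each pass over the Feeder converts items on the fly.
        # (A __next__ that contains `yield` would return a new generator on
        # every call and never raise StopIteration.)
        for item in self.data:
            yield self.convert(item)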

I call it like this:

import numpy as np

def reconstruct_vector(non_zero_indices, dimensionality):
    """
    returns a vector of zeros and ones reconstructed from a sparse vector
    and a dimensions value
    """
    vec = np.zeros(dimensionality)
    for i in non_zero_indices:
        vec[i] = 1
    return vec

item_to_input_vector = lambda item : reconstruct_vector(item[0], item[1])
item_to_target_vector = lambda item : np.array([1,0]) if item else np.array([0,1])

model.fit({'input': Feeder(X_train, item_to_input_vector)},
          {'targets': Feeder(Y_train, item_to_target_vector)},
          n_epoch=10,
          validation_set=(
              {'input': Feeder(X_test, item_to_input_vector)},
              {'targets': Feeder(Y_test, item_to_target_vector)}
          ))

But this doesn't work either; I get a cryptic error:

Exception in thread Thread-3:
Traceback (most recent call last):
  File "/usr/lib64/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib64/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/utils.py", line 187, in slice_array
    return X[start]
TypeError: only integer scalar arrays can be converted to a scalar index
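The traceback suggests that tflearn's feeder retrieves each batch with `utils.slice_array(self.feed_dict[key], batch_ids)`, i.e. by indexing the fed object with a collection of batch ids. An object that only supports `len()` and iteration, like `Feeder`, cannot serve that kind of access. Below is a minimal sketch of an indexable alternative. `LazyDenseArray` and `to_dense` are hypothetical names, and the assumption that indexing plus `len()` is all tflearn needs has not been verified: it may also inspect attributes such as `shape`.

```python
import numpy as np

class LazyDenseArray:
    """Array-like view over compressed samples that densifies only what is indexed.

    Supports len(), integer indexing, slices, and lists/arrays of integer
    indices (the X[batch_ids] style of access seen in the traceback).
    """

    def __init__(self, data, convert):
        self.data = data          # compressed samples, kept in memory
        self.convert = convert    # callable: compressed sample -> dense 1-D array

    def __len__(self):
        return len(self.data)

    def __getitem__(self, key):
        if isinstance(key, (int, np.integer)):
            return self.convert(self.data[key])
        if isinstance(key, slice):
            ids = range(*key.indices(len(self.data)))
        else:
            ids = np.asarray(key).ravel().tolist()  # list or array of batch ids
        # Only the requested rows are materialized as dense vectors.
        return np.stack([self.convert(self.data[i]) for i in ids])

def to_dense(item):
    """Hypothetical converter for (non_zero_indices, dimensionality) pairs."""
    indices, dim = item
    vec = np.zeros(dim)
    vec[indices] = 1
    return vec

X = LazyDenseArray([([0, 2], 4), ([1], 4), ([3], 4)], to_dense)
print(len(X))      # 3
print(X[[0, 2]])   # rows 0 and 2 as a (2, 4) array
```

With this design, memory use scales with the batch size rather than the dataset size, because dense rows are built only inside `__getitem__`.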

So what is the right way to solve this?
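Another way around the `len()` requirement is to densify the data one chunk at a time with a plain generator and train on each chunk separately. This assumes `model.fit` can be called repeatedly on successive chunks (for example with `n_epoch=1` inside an outer epoch loop), which is not verified here against tflearn. `iter_dense_chunks` and `chunk_size` are hypothetical names; the target encoding mirrors `item_to_target_vector` from the question.

```python
import numpy as np

def iter_dense_chunks(X_compressed, Y, chunk_size):
    """Yield (inputs, targets) dense NumPy arrays, one chunk at a time.

    X_compressed: list of (non_zero_indices, dimensionality) pairs.
    Y: list of booleans; True -> [1, 0], False -> [0, 1].
    Only one chunk of dense vectors exists in memory at any moment.
    """
    for start in range(0, len(X_compressed), chunk_size):
        items = X_compressed[start:start + chunk_size]
        labels = Y[start:start + chunk_size]
        dim = items[0][1]  # assumes every sample has the same dimensionality
        inputs = np.zeros((len(items), dim))
        for row, (indices, _) in enumerate(items):
            inputs[row, indices] = 1
        targets = np.array([[1, 0] if y else [0, 1] for y in labels])
        yield inputs, targets

X_compressed = [([0, 2], 4), ([1], 4), ([], 4), ([3], 4), ([0, 1, 2, 3], 4)]
Y = [True, False, False, True, True]
for inputs, targets in iter_dense_chunks(X_compressed, Y, chunk_size=2):
    print(inputs.shape, targets.shape)
```

The last chunk may be smaller than `chunk_size`; here the shapes are (2, 4), (2, 4) and (1, 4) for the inputs.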

0 answers