The fit function of a tflearn model can be passed training and test data like this:
model = tflearn.DNN(nn)
model.fit({'input': X_train},
          {'targets': Y_train},
          n_epoch=10,
          validation_set=(
              {'input': X_test},
              {'targets': Y_test}
          ))
where nn is the model definition. But what if a set like X_train is too large to fit in memory?
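In my case, I compressed the sparse binary vectors (whose entries are 1 or 0) into lists of the indices of their non-zero entries, and stored each vector's dimensionality as an integer, which lets me reconstruct the original vectors.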
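To make the compressed format concrete, here is a minimal sketch of the round trip, assuming each compressed vector is a pair `(non_zero_indices, dimensionality)`. The helper `compress_vector` is hypothetical (the question only shows the reconstruction side), and the reconstruction uses NumPy fancy indexing instead of an explicit loop.

```python
import numpy as np

def compress_vector(vec):
    """Hypothetical helper: return (non_zero_indices, dimensionality) for a 0/1 vector."""
    vec = np.asarray(vec)
    return np.flatnonzero(vec).tolist(), vec.shape[0]

def reconstruct_vector(non_zero_indices, dimensionality):
    """Rebuild the dense 0/1 vector from its compressed form."""
    vec = np.zeros(dimensionality)
    vec[non_zero_indices] = 1  # fancy indexing sets every listed position at once
    return vec

dense = np.array([0, 1, 0, 0, 1, 0, 0, 0])
compressed = compress_vector(dense)
print(compressed)  # ([1, 4], 8)
assert np.array_equal(reconstruct_vector(*compressed), dense)
```

Only the index lists have to stay in memory; each dense vector exists only while it is being reconstructed.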
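The collection of compressed vectors does fit in memory, but the collection of full vectors does not. So instead I tried passing generators over X_train and the other sets (which now contain compressed vectors), so that the full vectors are generated on the fly. However, model.fit requires a len() function, so I defined a custom Feeder class like this: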
class Feeder:
    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return self

    def __next__(self):
        for item in self.data:
            yield self.convert(item)
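Independently of the tflearn error, note that a `yield` inside `__next__` turns `__next__` into a generator function: each `next()` call returns a fresh generator object instead of a converted vector. A plain `for` loop over such an object also never terminates, because `StopIteration` is never raised. The sketch below reproduces this with a stand-in class and shows the usual fix of making `__iter__` itself the generator. The class names `BrokenFeeder` and `LazyFeeder` are hypothetical.

```python
class BrokenFeeder:
    """Same structure as the Feeder above: __next__ contains a yield."""
    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __iter__(self):
        return self

    def __next__(self):
        for item in self.data:
            yield self.convert(item)


class LazyFeeder:
    """Hypothetical fix: __iter__ is a generator, so iteration yields converted items."""
    def __init__(self, data, convert):
        self.data = data
        self.convert = convert

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        for item in self.data:
            yield self.convert(item)


first = next(BrokenFeeder([1, 2, 3], lambda x: x * 10))
print(type(first).__name__)  # generator
print(list(LazyFeeder([1, 2, 3], lambda x: x * 10)))  # [10, 20, 30]
```

Fixing iteration alone probably does not satisfy tflearn: the traceback below suggests it fetches batches by indexing, not by iterating.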
I call it like this:
import numpy as np

def reconstruct_vector(non_zero_indices, dimensionality):
    """
    returns a vector of zeros and ones reconstructed from a sparse vector
    and a dimensions value
    """
    vec = np.zeros(dimensionality)
    for i in non_zero_indices:
        vec[i] = 1
    return vec

item_to_input_vector = lambda item : reconstruct_vector(item[0], item[1])
item_to_target_vector = lambda item : np.array([1,0]) if item else np.array([0,1])
model.fit({'input': Feeder(X_train, item_to_input_vector)},
          {'targets': Feeder(Y_train, item_to_target_vector)},
          n_epoch=10,
          validation_set=(
              {'input': Feeder(X_test, item_to_input_vector)},
              {'targets': Feeder(Y_test, item_to_target_vector)}
          ))
But this doesn't work either, because I get a cryptic error:
Exception in thread Thread-3:
Traceback (most recent call last):
  File "/usr/lib64/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib64/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "/usr/lib/python3.6/site-packages/tflearn/utils.py", line 187, in slice_array
    return X[start]
TypeError: only integer scalar arrays can be converted to a scalar index
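The last frames show `retrieve_data` passing `batch_ids` to `slice_array`, which then evaluates `X[start]`, i.e. it indexes the fed object with a whole batch of row ids at once. The same kind of `TypeError` appears whenever a plain Python list (or any sequence that only accepts a single integer index) is indexed with a NumPy integer array, while a NumPy array accepts it. The sketch below illustrates that contrast; reading it as the cause here is an inference from the traceback, not something confirmed against tflearn's source.

```python
import numpy as np

batch_ids = np.array([0, 2])  # a batch of row indices as a NumPy integer array
as_array = np.array([[0, 1], [1, 0], [1, 1]])
as_list = [[0, 1], [1, 0], [1, 1]]

# NumPy arrays support indexing with an array of row ids (fancy indexing).
print(as_array[batch_ids].tolist())  # [[0, 1], [1, 1]]

# A plain list only accepts a single integer index, so this raises TypeError.
error_message = None
try:
    as_list[batch_ids]
except TypeError as err:
    error_message = str(err)
print(error_message)
```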
So what is the right way to solve this?
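One direction worth trying is an object that keeps only the compressed rows and densifies just the rows requested by each index operation. This sketch assumes tflearn needs nothing beyond `len(obj)`, `obj[index_array]` and `obj[start:stop]`. That assumption is inferred from the traceback and has not been verified against tflearn's `DNN.fit`, which may convert or type-check its inputs elsewhere. The class `LazyDenseRows` is hypothetical.

```python
import numpy as np

class LazyDenseRows:
    """Hypothetical sketch: stores compressed rows, densifies only requested rows.

    Assumes the consumer calls len(obj), obj[int], obj[start:stop] or
    obj[array_of_row_ids], as the traceback suggests slice_array does.
    """
    def __init__(self, compressed_rows, convert):
        self.compressed_rows = compressed_rows  # e.g. list of (indices, dimensionality)
        self.convert = convert                  # compressed row -> dense 1-D array

    def __len__(self):
        return len(self.compressed_rows)

    def __getitem__(self, key):
        if isinstance(key, slice):
            ids = range(*key.indices(len(self)))
        elif np.ndim(key) == 0:
            return self.convert(self.compressed_rows[int(key)])
        else:
            ids = np.asarray(key).ravel()
        # Densify only this batch; assumes the selection is non-empty,
        # since np.stack needs at least one array.
        return np.stack([self.convert(self.compressed_rows[int(i)]) for i in ids])


def reconstruct_vector(non_zero_indices, dimensionality):
    vec = np.zeros(dimensionality)
    vec[non_zero_indices] = 1
    return vec

X_small = [([1], 4), ([0, 3], 4), ([], 4)]  # illustrative compressed rows
rows = LazyDenseRows(X_small, lambda item: reconstruct_vector(item[0], item[1]))
print(len(rows))                        # 3
print(rows[np.array([0, 2])].tolist())  # [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(rows[1:3].shape)                  # (2, 4)
```

The targets and validation sets would need the same wrapper. If tflearn rejects such objects, a common alternative is to write the dense data to an HDF5 file, which tflearn's `slice_array` comments indicate it can index.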