我了解如何将 keras.utils.Sequence 与一个数据文件一起使用。您可以对 keras.utils.Sequence 类进行子类化并实现其接口：__len__ 和 __getitem__。
例如:
def __len__(self):
    """Return the number of batches per epoch."""
    # Ceiling division: a trailing partial batch still counts as one batch.
    full_batches, remainder = divmod(self.no_examples, self.batch_size)
    return full_batches + (1 if remainder else 0)
def __getitem__(self, idx):
    """Build and return batch number ``idx``.

    NOTE(review): the original stub's body was only a comment, which is
    a SyntaxError in Python — a ``def`` needs at least one statement.
    A placeholder raise keeps the snippet valid.
    """
    # build the batch w/ idx and self.batch_size
    raise NotImplementedError
但是，如果您的数据分散在多个文件中怎么办？例如：如何仅用一个指针 idx 遍历所有批次？
答案 0（得分：1）
您可以设置(range,file_path)的映射
def __init__(self, file_paths, batch_size):
    """Index which range of global example indices each file holds.

    Args:
        file_paths: iterable of text-file paths, one example per line.
        batch_size: number of examples per batch.
    """
    self.batch_size = batch_size
    # (start, start + size) example-index range -> file covering it.
    self._mapping = dict()
    count = 0
    for file_path in file_paths:
        with open(file_path, 'r') as f:
            # Count lines lazily; the original len(f.readlines())
            # materialized the entire file in memory just to count.
            size = sum(1 for _ in f)
        self._mapping[(count, count + size)] = file_path
        count += size
    # Total number of examples across all files.
    self.no_examples = count
def _find_file_path(self, idx):
    """Map a global example index to ``(index_within_file, file_path)``.

    Raises:
        IndexError: if ``idx`` falls outside every recorded range
            (the original silently returned ``None``).
    """
    # Tuple-unpack the key directly; the original bound it to a local
    # named ``range``, shadowing the builtin.
    for (start, end), file_path in self._mapping.items():
        # Half-open test. The original's ``start <= idx and idx <= end``
        # matched two files at every boundary (end of one range equals
        # start of the next) and accepted idx == end, which is one past
        # the file's last example.
        if start <= idx < end:
            return (idx - start, file_path)
    raise IndexError(f"example index {idx} is out of range")
def __len__(self):
    """Denotes the number of batches per epoch."""
    # Integer ceiling division, equivalent to int(np.ceil(a / b)).
    return -(-self.no_examples // self.batch_size)
# NOTE(review): lru_cache on an instance method keys the cache on ``self``
# and keeps every instance alive for the cache's lifetime (a memory leak).
# The method never used ``self``, so a staticmethod with a function-level
# cache gives the same file memoization without retaining instances.
@staticmethod
@functools.lru_cache(maxsize=128)  # memoize file contents per path
def _read_file_data(file_path):
    """Read, cache, and return all lines of ``file_path`` (newlines kept)."""
    with open(file_path, 'r') as f:
        # readlines() already returns a list; the original's list()
        # wrapper was redundant.
        return f.readlines()
def __getitem__(self, idx):
    """Return batch ``idx`` as a list of example lines.

    The original returned a single line per index, but ``__len__``
    reports ceil(no_examples / batch_size) *batches*, so Keras would
    only ever see the first ``len(self)`` examples. This gathers a full
    batch, looking each example up individually so a batch may span a
    file boundary.
    """
    start = idx * self.batch_size
    if start >= self.no_examples:
        raise IndexError(f"batch index {idx} is out of range")
    stop = min(start + self.batch_size, self.no_examples)
    batch = []
    for example_idx in range(start, stop):
        in_file_idx, file_path = self._find_file_path(example_idx)
        batch.append(self._read_file_data(file_path)[in_file_idx])
    return batch
进一步的优化: