keras.utils.Sequence with multiple files

Date: 2019-04-29 23:48:56

Tags: tensorflow keras

I understand how to use keras.utils.Sequence with a single data file: you subclass keras.utils.Sequence and implement its interface, __len__ and __getitem__.

For example:

import numpy as np

def __len__(self):
    "Denotes the number of batches per epoch"
    return int(np.ceil(self.no_examples / float(self.batch_size)))

def __getitem__(self, idx):
    # build the batch using idx and self.batch_size
    ...
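For reference, a complete single-file version might look like the sketch below. The class name CsvSequence and the assumption that each row holds comma-separated numeric features with the label in the last column are illustrative, not from the original post:

import numpy as np
from tensorflow import keras

class CsvSequence(keras.utils.Sequence):
    def __init__(self, file_path, batch_size):
        self.batch_size = batch_size
        with open(file_path, 'r') as f:
            self.lines = f.readlines()
        self.no_examples = len(self.lines)

    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(self.no_examples / float(self.batch_size)))

    def __getitem__(self, idx):
        # slice out the idx-th batch and parse it into feature/label arrays
        batch = self.lines[idx * self.batch_size:(idx + 1) * self.batch_size]
        rows = np.array([[float(v) for v in line.split(',')] for line in batch])
        return rows[:, :-1], rows[:, -1]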

But what if your data is spread across multiple files? For example:

  • train_part1.csv
  • train_part2.csv
  • train_partn.csv

How can I iterate over all of the batches with just the single index idx?

1 answer:

Answer 0 (score: 1)

You can build a mapping from example-index ranges to file paths, i.e. (range, file_path):

import functools

import numpy as np

def __init__(self, file_paths, batch_size):
    self.batch_size = batch_size
    self._mapping = dict()
    count = 0
    for file_path in file_paths:
        with open(file_path, 'r') as f:
            size = len(f.readlines())
        # each file owns the half-open example-index range [count, count + size)
        self._mapping[(count, count + size)] = file_path
        count += size
    self.no_examples = count

def _find_file_path(self, idx):
    # linear scan over the ranges: O(n) in the number of files
    for (start, end), file_path in self._mapping.items():
        if start <= idx < end:  # half-open range avoids double-matching boundaries
            in_file_idx = idx - start
            return (in_file_idx, file_path)

def __len__(self):
    "Denotes the number of batches per epoch"
    return int(np.ceil(self.no_examples / float(self.batch_size)))

@functools.lru_cache(maxsize=128)  # memoize file contents so each file is read from disk at most once
def _read_file_data(self, file_path):
    with open(file_path, 'r') as f:
        return f.readlines()

def __getitem__(self, idx):
    # idx is a batch index; gather batch_size consecutive examples,
    # which may span a file boundary
    start = idx * self.batch_size
    batch = []
    for example_idx in range(start, min(start + self.batch_size, self.no_examples)):
        in_file_idx, file_path = self._find_file_path(example_idx)
        batch.append(self._read_file_data(file_path)[in_file_idx])
    return batch
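Assembled into a keras.utils.Sequence subclass, usage might look like the following. The class name MultiFileSequence is assumed here, and each returned batch is still a list of raw CSV lines that would need parsing into (inputs, targets) arrays before training:

seq = MultiFileSequence(
    file_paths=['train_part1.csv', 'train_part2.csv'],
    batch_size=32,
)
print(len(seq))    # number of batches per epoch
print(seq[0][:2])  # first two raw lines of the first batch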

Further optimizations:

  • Track memory consumption and drop memoized file contents when they no longer fit in memory (the maxsize argument of lru_cache bounds the number of cached files, but not their total size);
  • If there are a large number of files, implement a more efficient _find_file_path; the current implementation is O(n) in the number of files (see the binary-search sketch below).
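One way to make the lookup O(log n) is to keep the start offsets in a sorted list and binary-search them with the standard-library bisect module. This is a sketch of that idea, not part of the original answer; it replaces the dict with two parallel lists built in __init__:

import bisect

def __init__(self, file_paths, batch_size):
    self.batch_size = batch_size
    self._starts = []  # sorted start offset of each file's example range
    self._paths = []   # file path owning each start offset
    count = 0
    for file_path in file_paths:
        with open(file_path, 'r') as f:
            size = len(f.readlines())
        self._starts.append(count)
        self._paths.append(file_path)
        count += size
    self.no_examples = count

def _find_file_path(self, idx):
    # bisect_right finds the insertion point for idx, so the element just
    # before it is the largest start <= idx: O(log n) per lookup
    i = bisect.bisect_right(self._starts, idx) - 1
    return (idx - self._starts[i], self._paths[i])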