带有tfrecords或numpy的tf.keras中的数据管道

时间:2019-04-02 14:37:49

标签: tensorflow tensorflow-datasets tf.keras tensorflow2.0

我想在Tensorflow 2.0的tf.keras中使用大于我的ram的数据训练模型,但是这些教程仅显示具有预定义数据集的示例。

我遵循了本教程:

Load Images with tf.data,对于numpy数组或tfrecords上的数据,我无法做到这一点。

这是一个将数组转换为张量流数据集的示例。我想要的是使它适用于多个numpy数组文件或多个tfrecords文件。

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Since the dataset already takes care of batching,
# we don't pass a `batch_size` argument.
model.fit(train_dataset, epochs=3)

1 个答案:

答案 0 :(得分:0)

如果您有tfrecords个文件:

path = ['file1.tfrecords', 'file2.tfrecords', ..., 'fileN.tfrecords']
dataset = tf.data.Dataset.list_files(path, shuffle=True).repeat()
dataset = dataset.interleave(lambda filename: tf.data.TFRecordDataset(filename), cycle_length=len(path))
dataset = dataset.map(parse_function).batch()

parse_function处理解码和任何形式的扩充。

对于numpy数组,您可以从文件名列表或数组列表构造数据集。标签只是一个列表。或者可以在解析单个示例时从文件中获取它们。

path = #list of numpy arrays

path = os.listdir(path_to files)

dataset = tf.data.Dataset.from_tensor_slices((path, labels))
dataset = dataset.map(parse_function).batch()

parse_function处理解码:

def parse_function(filename, label):  #Both filename and label will be passed if you provided both to from_tensor_slices
    f = tf.read_file(filename)
    image = tf.image.decode_image(f)) 
    image = tf.reshape(image, [H, W, C])
    label = label #or it could be extracted from, for example, filename, or from file itself 
    #do any augmentations here
    return image, label

要解码.npy文件,最好的方法是不使用reshaperead_file的情况下使用decode_raw,但首先使用np.load加载numpy:

paths = [np.load(i) for i in ["x1.npy", "x2.npy"]]
image = tf.reshape(filename, [2])

或尝试使用decode_raw

f = tf.io.read_file(filename)
image = tf.io.decode_raw(f, tf.float32)

然后将批处理的数据集传递到model.fit(dataset)。 TensorFlow 2.0允许对数据集进行简单迭代。无需使用迭代器。即使在更高版本的1.x API中,您也可以将数据集传递给.fit方法

for example in dataset:
    func(example)