How do I shuffle the data at every epoch with the tf.data API in TensorFlow 2.0?

Asked: 2019-04-25 08:31:20

Tags: python-3.x tensorflow2.0

I am getting my hands dirty training models with TensorFlow 2.0. The new iteration features of the tf.data API are great. However, when I run the code below, I find that, unlike the iteration behavior of torch.utils.data.DataLoader, it does not automatically reshuffle the data at every epoch. How can I achieve that with TF 2.0?

import numpy as np
import tensorflow as tf
def sample_data():
    ...

data = sample_data()

NUM_EPOCHS = 10
BATCH_SIZE = 128

# Split the data into training (80%) and validation (20%) sets
mask = range(int(data.shape[0]*0.8), data.shape[0])
data_val = data[mask]
mask = range(int(data.shape[0]*0.8))
data_train = data[mask]

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                 shuffle(buffer_size=10000).\
                                 repeat(1).batch(BATCH_SIZE)
val_dset = tf.data.Dataset.from_tensor_slices(data_val).\
                               batch(BATCH_SIZE)


loss_metric = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(0.001)

@tf.function
def train_step(inputs):
    ...

for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()
    for inputs in train_dset:
        train_step(inputs)
    ...
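
A quick way to see whether the order really changes between epochs is to print the first few values of the first batch on each pass (a throwaway check, assuming data is a numeric NumPy array as returned by sample_data()):

for epoch in range(2):
    first_batch = next(iter(train_dset))              # first batch of this pass
    print(epoch, first_batch.numpy().flatten()[:5])   # identical values each epoch => no reshuffling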

1 Answer:

Answer 0 (score: 1)

The batches need to be reshuffled, i.e. apply the shuffle after batching:

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                 repeat(1).batch(BATCH_SIZE)

buffer_size = 100  # when shuffling after batch(), buffer_size is measured in batches
train_dset = train_dset.shuffle(buffer_size=buffer_size)
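
For reference, Dataset.shuffle also takes a reshuffle_each_iteration argument (True by default), so shuffling before batching and simply re-iterating the dataset in the epoch loop should also produce a different order on every pass. A minimal sketch, reusing the names from the question:

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                 shuffle(buffer_size=10000,
                                         reshuffle_each_iteration=True).\
                                 batch(BATCH_SIZE)

for epoch in range(NUM_EPOCHS):
    for inputs in train_dset:   # each pass over the dataset re-runs the shuffle
        train_step(inputs)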