I'm getting my hands dirty training models with TensorFlow 2.0. The new iteration features in the tf.data API are great. However, when I run the following code, I notice that, unlike the iteration behaviour of torch.utils.data.DataLoader, it does not automatically reshuffle the data at every epoch. How can I achieve that with TF 2.0?
import numpy as np
import tensorflow as tf

def sample_data():
    ...

data = sample_data()
NUM_EPOCHS = 10
BATCH_SIZE = 128

# Subsample the data
mask = range(int(data.shape[0]*0.8), data.shape[0])
data_val = data[mask]
mask = range(int(data.shape[0]*0.8))
data_train = data[mask]

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
    shuffle(buffer_size=10000).\
    repeat(1).batch(BATCH_SIZE)
val_dset = tf.data.Dataset.from_tensor_slices(data_val).\
    batch(BATCH_SIZE)

loss_metric = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(0.001)

@tf.function
def train_step(inputs):
    ...

for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()

    for inputs in train_dset:
        train_step(inputs)
        ...
Answer 0 (score: 1):
The batches need to be reshuffled:

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
    repeat(1).batch(BATCH_SIZE)
# buffer_size was undefined in the original snippet; any value at least as
# large as the number of batches gives a full shuffle of the batches.
train_dset = train_dset.shuffle(buffer_size=10000)

Note that calling shuffle after batch shuffles whole batches rather than individual examples.
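For reference, here is a minimal, self-contained sketch (assuming TensorFlow 2.x; the toy data and buffer sizes are illustrative, not from the original post) showing that `shuffle()` with `reshuffle_each_iteration=True` (the default) draws a fresh permutation every time the dataset is iterated, so a plain `for ... in dset` loop per epoch already sees the data in a new order:

```python
import numpy as np
import tensorflow as tf

# Small illustrative dataset: 10 integers, batched in pairs.
data = np.arange(10)

# shuffle() before batch(), with reshuffle_each_iteration=True (the default),
# reshuffles the individual examples on every pass over the dataset.
dset = (tf.data.Dataset.from_tensor_slices(data)
        .shuffle(buffer_size=10, reshuffle_each_iteration=True)
        .batch(2))

epoch_orders = []
for epoch in range(2):
    # Each iteration over dset triggers a new shuffle.
    order = np.concatenate([batch.numpy() for batch in dset])
    epoch_orders.append(order)

# Every epoch still contains each example exactly once.
for order in epoch_orders:
    assert sorted(order.tolist()) == list(range(10))
```

If you instead want the same order on every epoch (e.g. for debugging), pass `reshuffle_each_iteration=False` together with a fixed `seed`.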