我的数据存储在 TFRecord 文件中,所以我直接将 tf.data API 与 model.fit 一起使用。即
def data_preparing():
    """Build the training and validation input pipelines from TFRecord files.

    Reads `training_files` / `val_filenames` (defined in the enclosing scope),
    shuffles and repeats the records, decodes them with `decode_train` /
    `decode_val`, batches, and prefetches both pipelines.

    Returns:
        A tuple ``(dataset_train, dataset_val)`` of ``tf.data.Dataset`` objects.
    """
    # Training pipeline.
    dataset_train = tf.data.TFRecordDataset(training_files, num_parallel_reads=calls)
    # tf.contrib was removed in TF 2.x; shuffle().repeat() is the supported
    # equivalent of the old tf.contrib.data.shuffle_and_repeat fused op.
    dataset_train = dataset_train.shuffle(buffer_size=1000 + 4 * batch_size).repeat()
    dataset_train = dataset_train.map(decode_train, num_parallel_calls=calls)
    dataset_train = dataset_train.batch(batch_size)
    # Overlap preprocessing with training; AUTOTUNE picks the buffer size.
    dataset_train = dataset_train.prefetch(tf.data.experimental.AUTOTUNE)

    # Validation pipeline (smaller shuffle buffer than training).
    dataset_val = tf.data.TFRecordDataset(val_filenames, num_parallel_reads=calls)
    dataset_val = dataset_val.shuffle(buffer_size=4 * batch_size).repeat()
    dataset_val = dataset_val.map(decode_val, num_parallel_calls=calls)
    dataset_val = dataset_val.batch(batch_size)
    dataset_val = dataset_val.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset_train, dataset_val
# Build the input pipelines, then launch training.
dataset_train, dataset_val = data_preparing()

# Because the datasets repeat indefinitely, Keras needs explicit step counts
# to know where one epoch ends.
train_steps = int(np.ceil(total_train / batch_size))
val_steps = int(np.ceil(total_val / batch_size))

model.fit(
    dataset_train,
    epochs=epochs,
    verbose=1,
    steps_per_epoch=train_steps,
    validation_data=dataset_val,
    validation_steps=val_steps,
    callbacks=call_backs,
)
所以,有没有办法把 train_data 和 val_data(它们现在是 tf.data 数据集/张量)传入 tf.keras.callbacks.Callback,以便在每个 epoch 结束时分别为训练数据和验证数据计算混淆矩阵?
我可以用 fit_generator 来解决这个问题,但我不想使用它,因为与 model.fit 相比它太慢了。