Passing a tf.data iterator to a Keras callback to compute a confusion matrix

Time: 2018-11-06 03:48:15

Tags: keras

My data is stored in TFRecords, so I use the tf.data API directly with model.fit, i.e.:

def data_preparing():
    dataset_train = tf.data.TFRecordDataset(training_files, num_parallel_reads=calls)
    dataset_train = dataset_train.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1000 + 4 * batch_size))
    dataset_train = dataset_train.map(decode_train, num_parallel_calls=calls)
    dataset_train = dataset_train.batch(batch_size)
    dataset_train = dataset_train.prefetch(tf.contrib.data.AUTOTUNE)

    dataset_val = tf.data.TFRecordDataset(val_filenames, num_parallel_reads=calls)
    dataset_val = dataset_val.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=4 * batch_size))
    dataset_val = dataset_val.map(decode_val, num_parallel_calls=calls)
    dataset_val = dataset_val.batch(batch_size)
    dataset_val = dataset_val.prefetch(tf.contrib.data.AUTOTUNE)

    return dataset_train, dataset_val

dataset_train, dataset_val = data_preparing()

model.fit(dataset_train, epochs=epochs, verbose=1,
          steps_per_epoch=int(np.ceil(total_train / batch_size)),
          validation_data=dataset_val,
          validation_steps=int(np.ceil(total_val / batch_size)),
          callbacks=call_backs)

So, is there any way to pass dataset_train and dataset_val (which are tf.data datasets, i.e. tensors at this point) to a tf.keras.callbacks.Callback, so that I can compute a confusion matrix for both the training and the validation data at the end of each epoch?

I could solve this with fit_generator, but I don't want to use fit_generator because it is much slower than model.fit.
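One common approach (not from the question itself, just a sketch) is to give the callback its own reference to a finite batch iterable and tally the confusion matrix in `on_epoch_end`. The tallying logic is plain NumPy; in real code the class below would subclass `tf.keras.callbacks.Callback` and the `batches` argument would be built from the same TFRecord files (e.g. via `dataset.take(n)` iterated eagerly, or a separate non-repeating dataset). The base class and the `batches` construction are left out here so the sketch stays self-contained; names like `ConfusionMatrixCallback` are hypothetical.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Tally a (num_classes x num_classes) confusion matrix.
    Rows are true labels, columns are predicted labels."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm

class ConfusionMatrixCallback:
    """Sketch of a per-epoch confusion-matrix callback.

    In real code this would subclass tf.keras.callbacks.Callback (omitted
    here to keep the sketch TensorFlow-free). `batches` is assumed to be a
    finite iterable of (features, labels) pairs, e.g. a bounded, eagerly
    iterated copy of the validation dataset.
    """
    def __init__(self, model, batches, num_classes):
        self.model = model          # with a Keras base class, self.model is set for you
        self.batches = batches
        self.num_classes = num_classes

    def on_epoch_end(self, epoch, logs=None):
        cm = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)
        for x, y in self.batches:
            # model.predict returns per-class scores; argmax gives class ids
            preds = np.argmax(self.model.predict(x), axis=-1)
            cm += confusion_matrix(np.asarray(y), preds, self.num_classes)
        print("epoch %d confusion matrix:\n%s" % (epoch, cm))
        return cm
```

The key design point is that the callback does a second prediction pass over a bounded copy of the data each epoch, so the repeated/shuffled training pipeline passed to model.fit is left untouched.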

0 Answers:

There are no answers yet.