我正在训练一个有4个类别的分类模型。当我将数据生成器与numpy数组一起使用时,模型可以很好地训练,损失减少,并且预测效果很好。
但是,当我用相同的参数、完全相同的轮数(epoch)训练同一个模型时,训练损失确实下降到几乎相同的值,但即使在训练数据集上,预测准确率也低于50%。
我浏览了以下链接:
Keras model failed to learn anything after changing to use tf.data api
https://github.com/tensorflow/tensorflow/issues/22190
我对损失函数进行了调整(从 categorical_crossentropy 改为 tf.keras.losses.sparse_categorical_crossentropy),但是问题仍然存在。我正在使用 tensorflow-gpu 2.0。
编辑:这是我用来生成数据集的类:
import tensorflow as tf
from glob import glob
import os
from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator
这是我的代码:
import math
import numpy as np
import cv2
"""
This generator is a base class; subclasses are expected to override _create_iterator for more specific training cases.
"""
class TfrecordLoader(Sequence):
    """Keras ``Sequence`` serving shuffled batches of parsed TFRecord examples.

    Every ``*.tfrecord`` file directly under ``dataset_path`` is read through
    ``tf.data.TFRecordDataset`` and each serialized example is parsed into a
    dictionary of tensors keyed by feature name (``NUM_FRAMES``, ``HEIGHT``,
    ``WIDTH``, ``CHANNEL_DEPTH``, ``IMAGE_QUALITY``, ``RECORD_NAME``,
    ``VIEW_TYPE``, ``FRAMES_ARRAY``).

    The dataset length is taken to be the number of matching files.
    NOTE(review): this presumably relies on one record per file -- confirm
    against the code that writes the tfrecords.
    """

    def __init__(self, dataset_path, batch_size=10):
        """Locate the tfrecord files, build the parsed dataset and the first iterator.

        Args:
            dataset_path: Directory that is globbed for ``*.tfrecord`` files.
            batch_size: Number of records grouped into one batch.
        """
        # Let the Sequence base class initialise whatever state it keeps.
        super().__init__()

        # Load the tfrecords
        video_paths = glob(os.path.join(dataset_path, "*.tfrecord"))
        self._dataset_len = len(video_paths)

        # Deserialize the tfrecords based on the below format
        def _parse_dicom_feature_record_function(serialized_data):
            """Parse one serialized example into a dict of scalar tensors."""
            dicom_feature_description = {
                'NUM_FRAMES': tf.io.FixedLenFeature([], tf.int64),
                'HEIGHT': tf.io.FixedLenFeature([], tf.int64),
                'WIDTH': tf.io.FixedLenFeature([], tf.int64),
                'CHANNEL_DEPTH': tf.io.FixedLenFeature([], tf.int64),
                # 'LVEF': tf.io.FixedLenFeature([], dtype=tf.float32),
                'IMAGE_QUALITY': tf.io.FixedLenFeature([], tf.string),
                'RECORD_NAME': tf.io.FixedLenFeature([], tf.string),
                'VIEW_TYPE': tf.io.FixedLenFeature([], tf.string),
                # Raw frame bytes; decoded in convert_tfrecorddata_to_dict.
                'FRAMES_ARRAY': tf.io.FixedLenFeature([], tf.string),
            }
            return tf.io.parse_single_example(serialized_data, dicom_feature_description)

        self._video_dataset = tf.data.TFRecordDataset(video_paths).map(
            _parse_dicom_feature_record_function
        )
        self._batch_size = batch_size
        self._iterator = self._create_iterator()

    def _create_iterator(self):
        """Return a fresh iterator over the shuffled, batched dataset.

        The shuffle buffer spans ``self._dataset_len`` elements. This method is
        intended to be overridden for more specific training cases based on
        this dataset.
        """
        shuffled = self._video_dataset.shuffle(self._dataset_len)
        return iter(shuffled.batch(self._batch_size))

    def __len__(self):
        """Return the number of whole batches in one epoch.

        Only full batches are counted (floor division); the iterator itself is
        batched without ``drop_remainder``, so a trailing partial batch is not
        included in this count.
        """
        return self._dataset_len // self._batch_size

    def __getitem__(self, idx):
        """Return the next batch from the current iterator.

        ``idx`` is ignored: batches are served sequentially from the internal
        iterator, so random access and out-of-order indexing are not supported.

        Raises:
            IndexError: If the current iterator is exhausted. ``on_epoch_end``
                creates a new one.
        """
        try:
            return next(self._iterator)
        except StopIteration:
            # Do not let StopIteration escape: an enclosing for-loop would
            # swallow it silently and a generator would turn it into an
            # opaque RuntimeError.
            raise IndexError(
                "TfrecordLoader iterator is exhausted; "
                "call on_epoch_end() to start a new pass"
            ) from None

    def on_epoch_end(self):
        """Print an end-of-epoch banner and start a newly shuffled iterator."""
        print('\n\n ------------------------- EPOCH END ------------------------ \n\n')
        self._iterator = self._create_iterator()

    def convert_tfrecorddata_to_dict(self, record, batch_index):
        """Convert one record of a batch into a plain Python dictionary.

        Args:
            record: A batch as returned by ``__getitem__`` (dict of batched
                tensors keyed by feature name).
            batch_index: Position of the wanted record inside the batch.

        Returns:
            dict with ``int`` values for ``NUM_FRAMES``, ``HEIGHT``, ``WIDTH``
            and ``CHANNEL_DEPTH``, decoded ``str`` values for
            ``IMAGE_QUALITY``, ``RECORD_NAME`` and ``VIEW_TYPE``, and
            ``FRAMES_ARRAY`` as a uint8 NumPy array reshaped to
            ``(NUM_FRAMES, HEIGHT, WIDTH, CHANNEL_DEPTH)``.
        """
        converted_dict = {
            'NUM_FRAMES': int(record['NUM_FRAMES'][batch_index]),
            'HEIGHT': int(record['HEIGHT'][batch_index]),
            'WIDTH': int(record['WIDTH'][batch_index]),
            'CHANNEL_DEPTH': int(record['CHANNEL_DEPTH'][batch_index]),
            # 'LVEF': float(record['LVEF'][batch_index]),
            'IMAGE_QUALITY': record['IMAGE_QUALITY'][batch_index].numpy().decode(),
            'RECORD_NAME': record['RECORD_NAME'][batch_index].numpy().decode(),
            'VIEW_TYPE': record['VIEW_TYPE'][batch_index].numpy().decode(),
        }

        # Decode the raw frame bytes into a flat uint8 tensor, then restore
        # the (frames, height, width, channels) layout from the stored sizes.
        raw_frames_data = tf.io.decode_raw(record['FRAMES_ARRAY'][batch_index], tf.uint8)
        frames_shape = [
            converted_dict['NUM_FRAMES'],
            converted_dict['HEIGHT'],
            converted_dict['WIDTH'],
            converted_dict['CHANNEL_DEPTH'],
        ]
        converted_dict['FRAMES_ARRAY'] = tf.reshape(raw_frames_data, frames_shape).numpy()
        return converted_dict
def main():
    """Preview the first frame of the first record of each batch.

    Prints the TensorFlow version, loads the tfrecords under
    ``test_tfrecords/`` in batches of 10, and shows each previewed frame in an
    OpenCV window named after the record for about 200 ms.
    """
    print(tf.__version__)
    record_generator = TfrecordLoader('test_tfrecords/', 10)
    try:
        # Cycle through the dataset
        for record_batch in record_generator:
            # Convert the first record in the batch
            converted_record = record_generator.convert_tfrecorddata_to_dict(record_batch, 0)
            # View the first frame of the video
            squeezed_img = converted_record['FRAMES_ARRAY'][0].squeeze()
            cv2.imshow(converted_record['RECORD_NAME'], squeezed_img)
            cv2.waitKey(200)
    finally:
        # Close the preview windows even if loading or decoding fails midway.
        cv2.destroyAllWindows()
# Run the preview only when executed as a script, not when imported.
if __name__ == '__main__':
    main()