Training loss decreases, but the model fails to learn when I use the tensorflow.data API

Date: 2020-01-27 12:35:44

Tags: python tensorflow2.0 tensorflow-datasets

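I am training a classification model with 4 classes. When I use a data generator with numpy arrays, the model trains well: the loss decreases and the predictions are good.

However, when I train the same model with the same parameters for exactly the same number of epochs using the tf.data API instead, the training loss does decrease to nearly the same value, but the prediction accuracy is below 50%, even on the training dataset.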
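One thing that may be worth ruling out is that the accuracy is computed against labels collected in a different order than the predictions. This is purely an assumption; nothing in this question confirms it. By default, a shuffled tf.data pipeline yields a different order each time it is iterated, so labels gathered in one pass need not line up with predictions from another. The NumPy sketch below only simulates that effect with a hypothetical perfect classifier; it does not use the model or data from the question.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, num_classes = 1000, 4

# Hypothetical ground-truth labels for a 4-class problem.
labels = rng.integers(0, num_classes, size=num_samples)

# Hypothetical "perfect" model: every prediction equals the true label.
predictions = labels.copy()

# Labels and predictions taken in the same order: accuracy is 1.0.
aligned_acc = np.mean(predictions == labels)

# Labels read in a second, differently shuffled pass: accuracy drops
# to roughly chance level (about 1 / num_classes) despite a perfect model.
reshuffled_labels = labels[rng.permutation(num_samples)]
misaligned_acc = np.mean(predictions == reshuffled_labels)

print(f"aligned accuracy:    {aligned_acc:.3f}")
print(f"misaligned accuracy: {misaligned_acc:.3f}")
```

If this turned out to be the cause, collecting labels and predictions batch by batch within a single pass over the dataset would keep them aligned.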

I have gone through the following links:
Keras model failed to learn anything after changing to use tf.data api
https://github.com/tensorflow/tensorflow/issues/22190

I changed the loss function (from categorical_crossentropy to tf.keras.losses.sparse_categorical_crossentropy), but the problem persists. I am using tensorflow-gpu 2.0.
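For context on that loss change, the two losses differ only in the label format they expect: categorical cross-entropy takes one-hot vectors, while the sparse variant takes integer class indices. The sketch below re-implements both in NumPy to show that they give the same value for matching labels. The helper names `dense_ce` and `sparse_ce` and the probabilities are made up for illustration; they are not the Keras functions themselves.

```python
import numpy as np

def dense_ce(one_hot_labels, probs):
    """Cross-entropy for one-hot labels of shape (batch, classes)."""
    return -np.sum(one_hot_labels * np.log(probs), axis=-1)

def sparse_ce(int_labels, probs):
    """Cross-entropy for integer class labels of shape (batch,)."""
    rows = np.arange(len(int_labels))
    return -np.log(probs[rows, int_labels])

# Made-up softmax outputs for two samples of a 4-class problem.
probs = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.1, 0.2, 0.6, 0.1]])

int_labels = np.array([0, 2])            # sparse format
one_hot_labels = np.eye(4)[int_labels]   # one-hot format of the same labels

dense_loss = dense_ce(one_hot_labels, probs)
sparse_loss = sparse_ce(int_labels, probs)
print(dense_loss, sparse_loss)
```

Because the values agree whenever the label format matches the loss, swapping one loss for the other would not be expected to change training behavior on its own.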

Edit: here is my code, including the class I use to generate the dataset:

import math
import os
from glob import glob

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.preprocessing.image import ImageDataGenerator



"""
This generator is an abstract class as 
"""
class TfrecordLoader(Sequence):

    def __init__(self, dataset_path, batch_size=10):

        # Load the tfrecords
        video_paths = glob(os.path.join(dataset_path,"*.tfrecord"))
        self._dataset_len = len(video_paths)

        # Deserialize the tfrecords based on the below format
        def _parse_dicom_feature_record_function(serialized_data):

            # Deserialize the data
            dicom_feature_description = {
                'NUM_FRAMES': tf.io.FixedLenFeature([], tf.int64),
                'HEIGHT': tf.io.FixedLenFeature([], tf.int64),
                'WIDTH': tf.io.FixedLenFeature([], tf.int64),
                'CHANNEL_DEPTH': tf.io.FixedLenFeature([], tf.int64),
#                 'LVEF': tf.io.FixedLenFeature([], dtype=tf.float32),
                'IMAGE_QUALITY': tf.io.FixedLenFeature([], tf.string),
                'RECORD_NAME': tf.io.FixedLenFeature([], tf.string),
                'VIEW_TYPE': tf.io.FixedLenFeature([], tf.string),
                'FRAMES_ARRAY': tf.io.FixedLenFeature([], tf.string),
            }
            return tf.io.parse_single_example(serialized_data, dicom_feature_description)

        self._video_dataset = tf.data.TFRecordDataset(video_paths).map(_parse_dicom_feature_record_function)


        # The actual dataset used by this iterator randomly samples the frames of the passed data
        self._batch_size = batch_size
        self._iterator = self._create_iterator()
        return

    """
    Create iterator that iterates through loaded tfrecords.  Note, this is intended to be overridden for more
    specific training cases based on this dataset
    """
    def _create_iterator(self):
        # Return an iterator to that data
        return iter(self._video_dataset.shuffle(self._dataset_len).batch(self._batch_size))

    """
    Give the length of an epoch
    """
    def __len__(self):
        return math.floor(self._dataset_len / self._batch_size)

    """
    Returns one batch of data
    """
    def __getitem__(self, idx):
        return next(self._iterator)

    """
    Called at the end of each epoch to select a new shuffled balanced dataset
    """
    def on_epoch_end(self):
        print ('\n\n ------------------------- EPOCH END ------------------------ \n\n')
        self._iterator = self._create_iterator()
        return

    """
    Convenience method for converting a tfrecord from this dataset into a traditional python dictionary
    """
    def convert_tfrecorddata_to_dict(self, record, batch_index):
        converted_dict = {}
        converted_dict['NUM_FRAMES'] = int(record['NUM_FRAMES'][batch_index])
        converted_dict['HEIGHT'] = int(record['HEIGHT'][batch_index])
        converted_dict['WIDTH'] = int(record['WIDTH'][batch_index])
        converted_dict['CHANNEL_DEPTH'] = int(record['CHANNEL_DEPTH'][batch_index])
#         converted_dict['LVEF'] = float(record['LVEF'][batch_index])
        converted_dict['IMAGE_QUALITY'] = record['IMAGE_QUALITY'][batch_index].numpy().decode()
        converted_dict['RECORD_NAME'] = record['RECORD_NAME'][batch_index].numpy().decode()
        converted_dict['VIEW_TYPE'] = record['VIEW_TYPE'][batch_index].numpy().decode()

        # Convert the raw frames data from a byte string to a flat array of uint8 values
        raw_frames_data = tf.io.decode_raw(record['FRAMES_ARRAY'][batch_index], tf.uint8)
        converted_dict['FRAMES_ARRAY'] = tf.reshape(raw_frames_data, [converted_dict['NUM_FRAMES'],
                                                                      converted_dict['HEIGHT'],
                                                                      converted_dict['WIDTH'],
                                                                      converted_dict['CHANNEL_DEPTH']]).numpy()

        return converted_dict


def main():
    print (tf.__version__)
    record_generator =  TfrecordLoader('test_tfrecords/',10)

    # Cycle through the dataset
    for record_batch in record_generator:
        # convert the first record in the batch
        converted_record = record_generator.convert_tfrecorddata_to_dict(record_batch,0)

        # view the first frame of the video
        squeezed_img = converted_record['FRAMES_ARRAY'][0].squeeze()
        cv2.imshow(converted_record['RECORD_NAME'], squeezed_img)
        cv2.waitKey(200)
        cv2.destroyAllWindows()

    return



if __name__ == '__main__':
    main()
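One detail of the class above that can be checked in isolation: `__len__` uses `math.floor`, whereas `batch()` without `drop_remainder=True` also yields a final, smaller batch. The sketch below compares the two counts for hypothetical record counts. The helper names are illustrative only, and whether this has any bearing on the accuracy problem is not established here.

```python
import math

def batches_requested(num_records, batch_size):
    """Batches per epoch as reported by __len__ in the class above."""
    return math.floor(num_records / batch_size)

def batches_yielded(num_records, batch_size):
    """Batches produced by dataset.batch(batch_size) when the final
    partial batch is kept (drop_remainder left at its default)."""
    return math.ceil(num_records / batch_size)

batch_size = 10
for num_records in (20, 25):  # hypothetical numbers of tfrecord files
    requested = batches_requested(num_records, batch_size)
    yielded = batches_yielded(num_records, batch_size)
    unused = num_records - min(requested * batch_size, num_records)
    print(f"{num_records} records: {requested} requested, "
          f"{yielded} available, {unused} records unused per epoch")
```

With 25 records and a batch size of 10, two batches are requested per epoch and the iterator is rebuilt in `on_epoch_end`, so the remaining 5 records of that epoch's shuffle are never read. With fewer records than the batch size, `__len__` would return 0.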

0 answers:

No answers