Difference in model training performance between tf.data and NumPy data loading

Date: 2020-10-25 00:21:48

Tags: python numpy tensorflow tensorflow2.0 tensorflow-datasets

Compared with the original implementation, in which the model takes NumPy arrays as input, I am trying to re-implement the data loading and feeding for model training using tf.data. Both versions are shown below.

Original data loading for model training

import os
import re
from scipy import ndimage, misc
from skimage.transform import resize, rescale
from matplotlib import pyplot
import numpy as np

def train_batches(just_load_dataset=False):

    batches = 256 # Number of images to have at the same time in a batch

    batch = 0 # Number of images in the current batch (grows over time, then resets for each batch)
    batch_nb = 0 # Current batch index

    max_batches = -1 # Set to a positive value to train on only a limited number of batches and finish even faster
    
    ep = 4 # Number of epochs

    images = []
    x_train_n = []
    x_train_down = []
    
    x_train_n2 = [] # Resulting high res dataset
    x_train_down2 = [] # Resulting low res dataset
    
    for root, dirnames, filenames in os.walk("./dataset/"):
        for filename in filenames:
            if re.search(r"\.(jpg|jpeg|JPEG|png|bmp|tiff)$", filename):
                if batch_nb == max_batches: # If we limit the number of batches, just return earlier
                    return x_train_n2, x_train_down2
                filepath = os.path.join(root, filename)
                image = pyplot.imread(filepath)
                if len(image.shape) > 2: # Only process images that have a channel dimension (skip grayscale)
                    image_resized = resize(image, (256, 256)) # Resize the image so that every image is the same size
                    x_train_n.append(image_resized) # Add this image to the high res dataset
                    x_train_down.append(rescale(rescale(image_resized, 0.5, multichannel=True), 2.0, multichannel=True)) # Rescale it 0.5x and 2x so that it is a low res image but still has 256x256 resolution
                    batch += 1
                    if batch == batches:
                        batch_nb += 1

                        x_train_n2 = np.array(x_train_n)
                        x_train_down2 = np.array(x_train_down)
                        
                        if just_load_dataset:
                            return x_train_n2, x_train_down2
                        
                        print('Training batch', batch_nb, '(', batches, ')')
                        
                        # model is the Keras model, assumed to be defined globally
                        model.fit(x_train_down2, x_train_n2,
                            epochs=ep,
                            batch_size=10,
                            shuffle=True,
                            validation_split=0.15)
                    
                        x_train_n = []
                        x_train_down = []
                    
                        batch = 0

    return x_train_n2, x_train_down2
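
For reference, skimage's resize and rescale convert images to floating point and scale the values to [0, 1] by default (preserve_range=False), so the arrays handed to model.fit in this version are floats in [0, 1]. A minimal check, assuming some sample image exists under ./dataset/ (the filename below is hypothetical):

from matplotlib import pyplot
from skimage.transform import resize

img = pyplot.imread("./dataset/sample.jpg")  # hypothetical sample file
resized = resize(img, (256, 256))            # skimage converts to float in [0, 1]
print(resized.dtype, resized.min(), resized.max())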

My implementation using the tf.data API

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

def parse_image(downsampled_path, img_path):
    # Load and resize the original high-res image
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [256, 256])

    # Load and resize the pre-downsampled low-res image
    downsampled = tf.io.read_file(downsampled_path)
    downsampled = tf.image.decode_png(downsampled, channels=3)
    downsampled = tf.image.resize(downsampled, [256, 256])

    return (downsampled, image)


def load_dataset(data, image_paths, batch_size, shuffle):
    if shuffle:
        # Cache, prefetch, shuffle (buffer sized to the whole dataset; a random or zero-sized buffer would error), then batch
        data = data.cache().prefetch(AUTOTUNE).shuffle(len(image_paths)).batch(batch_size)
    else:
        # Batch and prefetch
        data = data.cache().batch(batch_size).prefetch(AUTOTUNE)
    return data

# downsampled_*_images / original_*_images are lists of image file paths prepared elsewhere
train_data = tf.data.Dataset.from_tensor_slices((downsampled_train_images, original_train_images))
valid_data = tf.data.Dataset.from_tensor_slices((downsampled_valid_images, original_valid_images))

train_data = train_data.map(parse_image, num_parallel_calls=AUTOTUNE)
valid_data = valid_data.map(parse_image, num_parallel_calls=AUTOTUNE)

train_dataset = load_dataset(train_data, original_train_images, 32, False)
valid_dataset = load_dataset(valid_data, original_valid_images, 32, False)

history = model.fit(train_dataset,
                    epochs=10,
                    steps_per_epoch=30,
                    validation_data=valid_dataset,
                    validation_steps=2)
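
For comparison with the skimage pipeline above, the tensors this dataset emits can be inspected directly. Note that tf.image.decode_jpeg / decode_png return uint8 values in [0, 255], and tf.image.resize casts them to float32 without rescaling, so unless something normalizes them the batches here stay in the [0, 255] range. A quick sketch using the train_dataset built above:

# Print the dtype and value range of one (downsampled, original) batch
for downsampled, image in train_dataset.take(1):
    print(downsampled.dtype, tf.reduce_min(downsampled).numpy(), tf.reduce_max(downsampled).numpy())
    print(image.dtype, tf.reduce_min(image).numpy(), tf.reduce_max(image).numpy())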

When I use the NumPy-based loading, the input is essentially NumPy arrays of 256 images at a time. In my implementation I am trying to load the images with tf.data instead, to get the benefits of prefetching, AUTOTUNE-driven parallelism, and so on.

The problem

The problem I am facing is that the model starts training with either option, but with the first function the loss in the first iterations is around 0.145, whereas with the tf.data version it is around 15479.0 or similar. The loss used is mean squared error. I cannot figure out why there is this difference in training behaviour between the two. Can anyone guide/help me here? Thanks :)
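
One hypothesis: since the loss is mean squared error, it scales with the square of the pixel range, so inputs in [0, 255] instead of [0, 1] would inflate the loss by roughly 255^2 ≈ 65,000x, which is the same order of magnitude as the gap above (0.145 × 65,025 ≈ 9,400). If the value ranges do differ, dividing by 255 inside parse_image would put the tf.data pipeline on the same [0, 1] scale as the skimage version. A sketch of that change (an assumption on my part, not a confirmed fix):

def parse_image(downsampled_path, img_path):
    image = tf.io.read_file(img_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [256, 256]) / 255.0  # scale to [0, 1]

    downsampled = tf.io.read_file(downsampled_path)
    downsampled = tf.image.decode_png(downsampled, channels=3)
    downsampled = tf.image.resize(downsampled, [256, 256]) / 255.0  # scale to [0, 1]

    return (downsampled, image)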

0 Answers:

No answers yet.