TensorFlow:从numpy数组获取3D CNN的输入批处理

时间:2019-12-04 21:08:15

标签: python numpy tensorflow keras

我有3D图像(tiff)数据和文件夹中的每个卷。我想读取数据并为卷积网络制作批处理张量。我可以将数据读取为numpy数组,但不知道如何为CNN输入批处理张量。这是我的代码

import os
import tensorflow as tf
import numpy as np
from skimage import io
from matplotlib import pyplot as plt
from pathlib import Path
data_dir = 'C:/Users/myname/Documents/Projects/Segmentation/DeepLearning/L-net/data/'
data_folders = os.listdir(data_dir)
train_input = []
train_output = []
test_input = []
test_output = []

for idx, folder in enumerate(data_folders):
        im = io.imread(data_dir+folder+'/f0.tiff')
        im = im/im.max()
        train_input.append(tf.convert_to_tensor(im, dtype=tf.float32))
        im = io.imread(data_dir+folder+'/g0.tiff')
        im = im/im.max()
        train_output.append(tf.convert_to_tensor(im, dtype=tf.float32))

由于我对CNN使用3D滤镜,因此输入应为5D tesnor。有人可以帮我弄这个吗?谢谢。

1 个答案:

答案 0 :(得分:0)

采用这种方法,您必须立即将所有数据加载到内存中,并且还必须注意所有方面。我建议使用Keras flow_from_directorygenerators。 Keras具有此类ImageDataGenerator的类,该类使用户可以从目录执行图像收集,将所有图像更改为所需的大小,对它们进行随机播放,...。您可以在其网站上找到文档here

  

下载火车数据集和测试数据集,将它们提取到2个不同的文件夹中,分别命名为“ train”和“ test”。火车文件夹应包含“ n”个文件夹,每个文件夹均包含各自类别的图像。例如,在“狗与猫”数据集中,火车文件夹应具有2个文件夹,分别是“狗”和“猫”,其中包含各自的图像。

这是有关如何为模型输入创建数据集的示例:

train_generator = train_datagen.flow_from_directory(
    directory=r"C:/Users/myname/Documents/Projects/Segmentation/DeepLearning/L-net/data/",
    target_size=(224, 224), # the size of your input images
    color_mode="rgb", # could be grayscale or rgb
    batch_size=32, # Number of images in each batsh
    class_mode="categorical",
    shuffle=True, # Whether to shuffle the images or not
    seed=42 # Random seed for applying random image augmentation
)

您可以像这样进行训练:

STEP_SIZE_TRAIN=train_generator.n//train_generator.batch_size
STEP_SIZE_VALID=valid_generator.n//valid_generator.batch_size
model.fit_generator(generator=train_generator,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    validation_data=valid_generator,
                    validation_steps=STEP_SIZE_VALID,
                    epochs=10
)