使用fit_generator

时间:2019-05-11 18:04:19

标签: python tensorflow keras conv-neural-network

我正在尝试使用fit_generator训练类似U-net的模型。我的图片是3D(nrrd格式),我将其ID保留在数据框中。

大致来说,我的生成器如下所示:

import os
import nrrd
import numpy as np

def my_generator(samples_dataframe, path_to_sample, path_to_mask, batch_size=3):

    num_samples = len(samples_dataframe) 
    while True: 

        for i in range(0, num_samples, batch_size):

            batch_samples = samples_dataframe[i:i+batch_size]    

            X_train = []
            Y_train = []

            batch_iterator = 0

            for batch_iterator in range(len(batch_samples)):

                ID = batch_samples.at[batch_iterator,'ID']

                sample, _ = nrrd.read(os.path.join(path_to_sample, ID))                 
                mask, _ = nrrd.read(os.path.join(path_to_mask, ID)) 

                X_train.append(sample)
                Y_train.append(mask)                              

            X_train = np.stack(X_train)
            X_train = np.expand_dims(X_train, axis=-1) 
            Y_train = np.stack(Y_train)
            Y_train = np.expand_dims(Y_train, axis=-1)            

            yield  X_train, Y_train

我已经单独测试了生成器,它似乎可以正确地生产批次。然后,将其输入到fit_generator函数,如下所示:

nb_batches = 3
nb_epoch = 100
train_steps = np.ceil(len(train_dataframe)/nb_batches)
val_steps = np.ceil(len(val_dataframe)/nb_batches)

train_samples = my_generator(train_dataframe, path_to_sample, path_to_mask, nb_batches)
val_samples = my_generator(val_dataframe, path_to_sample, path_to_mask, nb_batches)

model = my_model() # 3D-unet-like model

model.fit_generator(
       train_samples, 
       steps_per_epoch = train_steps, 
       epochs = nb_epoch,
       validation_data = val_samples,
       validation_steps=val_steps)

无论图像大小或批处理大小如何,在运行fit_generator时,我总是收到“内核死亡,重新启动”的信息。

我在fit_generator中正确使用my_generator吗?我想念什么?

我正在Ubuntu的Anaconda中使用Spyder,Python 3.6,tensorflow-gpu 1.9.0,keras 2.2.4。

1 个答案:

答案 0 :(得分:0)

尝试使用调试器查看代码失败的地方(我知道这是fit_generator,但请尝试查看代码中的确切问题)。

在这里也提出了类似的问题:kernel died restarting, whenever i'm training a model