使用fit_generator进行内存泄漏

时间:2017-05-11 13:08:18

标签: keras

def g():
    f_train=open('train.txt','r')
    f_label=open('type_train.txt','r')
    train=[]
    label=[]
    count=0
    num=0
    for i,j in zip(f_train,f_label):

        train.append(np.array(i.strip().split(',')))
        label.append(np.array(j.strip().split()))
        count+=1
        if count==200 :
            count=0

            train=np.array(train,int)
            train=pad_sequences(train,140)
            label=np.array(label,int)
            yield (train,label)
            num+=1
            #loss=model.train_on_batch=(train,label)
            #print (num)
            train=[]
            label=[]
        if num==19000:
            break
    f_train.close()
    f_label.close()

model = Sequential()
model.add(Embedding(215625 + 1,20,input_length=140,trainable=True))
#del word_index
model.add(LSTM(30))
model.add(Dense(6,activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])


model.fit_generator(g(),steps_per_epoch=19000,epochs=3)

抱歉我的英语不好。 感谢您的任何帮助。 我的内存为16G,运行在6G gpu上。我找不到发电机的任何问题。使用theano后端。

1 个答案:

答案 0 :(得分:0)

是否与以下已知问题有关? https://github.com/fchollet/keras/issues/3675

该问题的基本建议是什么,以及对我有用的是将pickle_safe=False添加到fit_generator来电。