def g():
f_train=open('train.txt','r')
f_label=open('type_train.txt','r')
train=[]
label=[]
count=0
num=0
for i,j in zip(f_train,f_label):
train.append(np.array(i.strip().split(',')))
label.append(np.array(j.strip().split()))
count+=1
if count==200 :
count=0
train=np.array(train,int)
train=pad_sequences(train,140)
label=np.array(label,int)
yield (train,label)
num+=1
#loss=model.train_on_batch=(train,label)
#print (num)
train=[]
label=[]
if num==19000:
break
f_train.close()
f_label.close()
model = Sequential()
model.add(Embedding(215625 + 1,20,input_length=140,trainable=True))
#del word_index
model.add(LSTM(30))
model.add(Dense(6,activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.fit_generator(g(),steps_per_epoch=19000,epochs=3)
抱歉我的英语不好。 感谢您的任何帮助。 我的内存为16G,运行在6G gpu上。我找不到发电机的任何问题。使用theano后端。
答案 0 :(得分:0)
是否与以下已知问题有关? https://github.com/fchollet/keras/issues/3675
该问题的基本建议是什么,以及对我有用的是将pickle_safe=False
添加到fit_generator
来电。