我有一个用喀拉拉邦写的GAN。它有2个网络,一个使用keras自定义损失函数并使用train_on_batch更新,这可以正常工作并使用GPU。第二个网络使用K.get_updates和K.function更新,它可以工作,但似乎在CPU而非GPU上进行训练。
训练第一个网络时,GPU负载达到最大,然后在训练第二个网络时,GPU负载降为0。
如果我将网络改回使用train_on_batch进行训练,则它将使用GPU。但是,我需要get_updates和K.function的功能。
这是我的训练功能:
def train_combo(self):
#input = K.placeholder(shape=[None, 100])
var = K.placeholder(shape=[None,1])
loss = K.sqrt(K.square(K.std(self.combo.output)- K.std(var)))
optimizer = Adam(lr=0.001)
updates = optimizer.get_updates(self.combo.trainable_weights,[], loss)
train = K.function([self.combo.input,var], [loss], updates=updates)
return train
这就是我所说的:
train = self.train_combo()
loss = train([vectors ,var])
我希望它可以在GPU上运行,并且似乎正在CPU上运行