Keras未将GPU用于K.get_updates和K.function

时间:2019-05-03 23:43:40

标签: tensorflow keras gpu

我有一个用喀拉拉邦写的GAN。它有2个网络,一个使用keras自定义损失函数并使用train_on_batch更新,这可以正常工作并使用GPU。第二个网络使用K.get_updates和K.function更新,它可以工作,但似乎在CPU而非GPU上进行训练。

训练第一个网络时,GPU负载达到最大,然后在训练第二个网络时,GPU负载降为0。

如果我将网络改回使用train_on_batch进行训练,则它将使用GPU。但是,我需要get_updates和K.function的功能。

这是我的训练功能:

def train_combo(self):
    #input = K.placeholder(shape=[None, 100])
    var = K.placeholder(shape=[None,1])

    loss = K.sqrt(K.square(K.std(self.combo.output)- K.std(var)))
    optimizer = Adam(lr=0.001)
    updates = optimizer.get_updates(self.combo.trainable_weights,[], loss)

    train = K.function([self.combo.input,var], [loss], updates=updates)
    return train

这就是我所说的:

train = self.train_combo()
loss = train([vectors ,var])

我希望它可以在GPU上运行,并且似乎正在CPU上运行

0 个答案:

没有答案