如何在Keras中缓存图层激活?

时间:2019-01-29 23:26:49

标签: python keras

我训练了一个神经网络,其中第一层在Keras中具有固定权重(不可训练)。

在训练期间,这些层执行的计算非常密集。为每个输入缓存层激活并在下一个时期传递相同的输入数据时重复使用它们是很有意义的,以节省计算时间。

是否可以在Keras中实现这种行为?

1 个答案:

答案 0 :(得分:0)

您可以将模型分为两个不同的模型。例如,以下代码段x_对应于您的中间激活:

from keras.models import Model
from keras.layers import Input, Dense
import numpy as np


nb_samples = 100
in_dim = 2
h_dim = 3
out_dim = 1

a = Input(shape=(in_dim,))
b = Dense(h_dim, trainable=False)(a)
model1 = Model(a, b)
model1.compile('sgd', 'mse')

c = Input(shape=(h_dim,))
d = Dense(out_dim)(c)
model2 = Model(c, d)
model2.compile('sgd', 'mse')


x = np.random.rand(nb_samples, in_dim)
y = np.random.rand(nb_samples, out_dim)
x_ = model1.predict(x)  # Shape=(nb_samples, h_dim)

model2.fit(x_, y)