在Keras的Lambda层内执行一种热编码以避免内存问题:问题

时间:2019-07-15 04:05:10

标签: python tensorflow keras

我已经在喀拉拉邦建立了一个顺序模型,用于生成音乐序列。很简单,具有LSTM和密集的softmax。我有333种可能的音乐事件

我知道model.fit()需要内存中的所有训练数据,如果它是一个热编码的话,这是一个问题。因此,我给模型一个整数作为输入,将其转换为Lambda层中的一种热编码,然后使用稀疏分类交叉熵进行损失。因为每一批都将即时转换为一种热编码,所以我认为这可以解决我的内存问题。但是,相反,它在试穿开始时就挂起了,即使小批量也充满了我的记忆。显然,鉴于我是新手,所以我对keras的工作方式并不了解,这不足为奇(请注意,请指出我的代码中任何过于幼稚的内容)。

1)幕后发生了什么?我不了解的喀拉拉邦是什么?似乎keras会继续进行,并在进行任何训练之前在我所有的训练示例上运行Lambda层。

2)我该如何解决这个问题,并让keras真正做到这一点呢?我可以使用目前正在使用的model.fit()来解决它,还是需要model.fit_generator()来解决它,对我来说看起来可以很容易地解决这个问题?

这是我的一些代码:

def musicmodel(Tx, n_a, n_values):
"""
Arguments:
Tx -- length of a sequence in the corpus
n_a -- the number of activations used in our model (for the LSTM)
n_values -- number of unique values in the music data 

Returns:
model -- a keras model
"""

# Define the input with a shape 
X = Input(shape=(Tx,))

# Define s0, initial hidden state for the decoder LSTM
a0 = Input(shape=(n_a,), name='a0')
c0 = Input(shape=(n_a,), name='c0')
a = a0
c = c0

# Create empty list to append the outputs to while iterating
outputs = []

# Step 2: Loop
for t in range(Tx):

    # select the "t"th time step from X. 
    x = Lambda(lambda x: x[:,t])(X)
    # We need the class represented in one hot fashion:
    x = Lambda(lambda x: tf.one_hot(K.cast(x, dtype='int32'), n_values))(x)
    # We then reshape x to be (1, n_values)
    x = reshapor(x)
    # Perform one step of the LSTM_cell
    a, _, c = LSTM_cell(x, initial_state=[a, c])
    # Apply densor to the hidden state output of LSTM_Cell
    out = densor(a)
    # Add the output to "outputs"
    outputs.append(out)

# Step 3: Create model instance
model = Model(inputs=[X,a0,c0],outputs=outputs)

return model

然后我拟合我的模型:

model = musicmodel(Tx, n_a, n_values)

opt = Adam(lr=0.01, beta_1=0.9, beta_2=0.999, decay=0.01)

model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

a0 = np.zeros((m, n_a))
c0 = np.zeros((m, n_a))

model.fit([X, a0, c0], list(Y), validation_split=0.25, epochs=600, verbose=2, batch_size=4)

0 个答案:

没有答案