在10 GB的GPU RAM中安装MultiHeadAttention模型

时间:2019-04-29 16:32:51

标签: python tensorflow keras

我已经编写了一个自定义的自我注意模型,用于在喀拉拉邦进行序列标记。

import keras.layers as ll
from keras import Model
from keras_pos_embd import TrigPosEmbedding
from keras_multi_head import MultiHeadAttention

inputs = ll.Input(shape=(None,))
x = ll.Embedding(10000, 1024)(inputs)
x = TrigPosEmbedding(mode='add')(x)
x = MultiHeadAttention(head_num=8)(x)
x = ll.Dense(units = 512, activation='relu')(x)
x = ll.Dense(units = 4, activation='softmax')(x)
outputs = x
model = Model(inputs, outputs)
model.summary()

但是,此模型过于占用内存。我正在尝试在很长的序列(长度为20000)上进行训练,当尝试对其进行训练时,当尝试分配形状为[16,20000,20000]的张量时,它给了我一个OOM(根据我的计算,仅分配此张量将需要> 150GB的RAM!)。

我尝试将batch_size减小为1,但是它仍然拒绝容纳在内存中。

我需要以某种方式对其进行修改,以使其适合10 GB RAM GPU

我该如何处理?

0 个答案:

没有答案