我已经编写了一个自定义的自我注意模型,用于在喀拉拉邦进行序列标记。
import keras.layers as ll
from keras import Model
from keras_pos_embd import TrigPosEmbedding
from keras_multi_head import MultiHeadAttention
inputs = ll.Input(shape=(None,))
x = ll.Embedding(10000, 1024)(inputs)
x = TrigPosEmbedding(mode='add')(x)
x = MultiHeadAttention(head_num=8)(x)
x = ll.Dense(units = 512, activation='relu')(x)
x = ll.Dense(units = 4, activation='softmax')(x)
outputs = x
model = Model(inputs, outputs)
model.summary()
但是,此模型过于占用内存。我正在尝试在很长的序列(长度为20000)上进行训练,当尝试对其进行训练时,当尝试分配形状为[16,20000,20000]的张量时,它给了我一个OOM(根据我的计算,仅分配此张量将需要> 150GB的RAM!)。
我尝试将batch_size减小为1,但是它仍然拒绝容纳在内存中。
我需要以某种方式对其进行修改,以使其适合10 GB RAM GPU
我该如何处理?