Resource exhausted: OOM when allocating tensor with shape [845246,300]

Date: 2018-08-01 12:42:08

Tags: tensorflow keras nlp

I'm working with a sequence-to-sequence language model, and after changing the code to pass custom word-embedding weights to the Embedding layer, I get an OOM error when trying to train on the GPU.

Here is the relevant code:

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

def create_model(word_map, X_train, Y_train, vocab_size, max_length):
    # define model
    model = Sequential()
    # get custom embedding weights as matrix
    embedding_matrix = get_weights_matrix_from_word_map(word_map)
    model.add(Embedding(len(word_map)+1, 300, weights=[embedding_matrix], input_length=max_length-1))
    model.add(LSTM(50))
    model.add(Dense(vocab_size, activation='softmax'))
    print(model.summary())
    # compile network
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_train, Y_train, epochs=100, verbose=2)
    return model
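For scale: the tensor named in the error, shape [845246, 300] in float32, is about 1 GB on its own, and the Adam optimizer keeps per-parameter state of the same shape (gradient plus two moment estimates), so several embedding-sized tensors must coexist on the GPU. A quick back-of-the-envelope check (plain arithmetic, no TensorFlow needed):

```python
# Rough memory footprint of the tensor from the OOM message.
rows, cols = 845246, 300          # shape reported in the error
bytes_per_float32 = 4

one_tensor = rows * cols * bytes_per_float32
print(f"one [845246, 300] float32 tensor: {one_tensor / 1e9:.2f} GB")

# Adam keeps the weights plus a gradient and the m and v slots of the
# same shape, so roughly four such tensors must fit at once.
adam_total = 4 * one_tensor
print(f"weights + grad + m + v: {adam_total / 1e9:.2f} GB")
```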

Here is the full error log from the server:

    File "/home2/slp24/thesis/UpdatedLanguageModel_7_31.py", line 335, in create_model_2
    model.fit(X_train, Y_train, batch_size=32, epochs=1, verbose=2)  ## prev X, y
  File "/opt/python-3.4.1/lib/python3.4/site-packages/keras/models.py", line 963, in fit
    validation_steps=validation_steps)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/keras/engine/training.py", line 1682, in fit
    self._make_train_function()
  File "/opt/python-3.4.1/lib/python3.4/site-packages/keras/engine/training.py", line 990, in _make_train_function
    loss=self.total_loss)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/keras/optimizers.py", line 466, in get_updates
    m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/ops/math_ops.py", line 898, in binary_op_wrapper
    y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 932, in convert_to_tensor
    as_ref=False)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1022, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/ops/gradients_impl.py", line 100, in _IndexedSlicesToTensor
    value.values, value.indices, value.dense_shape[0], name=name)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/ops/gen_math_ops.py", line 5186, in unsorted_segment_sum
    num_segments=num_segments, name=name)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 3160, in create_op
    op_def=op_def)
  File "/opt/python-3.4.1/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[845246,300] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
         [[Node: training/Adam/mul_2/y = UnsortedSegmentSum[T=DT_FLOAT, Tindices=DT_INT32, Tnumsegments=DT_INT32, _device="/job:localhost/replica:0/task:0/device:GPU:0"](training/Adam/gradients/embedding_1/Gather_grad/Reshape, training/Adam/gradients/embedding_1/Gather_grad/Reshape_1/_101, training/Adam/mul_2/strided_slice)]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

EDIT:

What I have tried so far:

  • Adding batching, starting with batch_size = 32
  • I'm currently working on reducing the number of output classes from 845,286. I think something went wrong when computing the custom embedding matrix, specifically when "connecting" the vocabulary token indices assigned during preprocessing with the Keras-assigned y_categorical values the model uses.
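One way to narrow down the second point is a sanity check that the three vocabulary sizes agree: the Embedding layer's input_dim, the weight matrix's first dimension, and the softmax vocab_size. A minimal sketch with toy stand-ins (the real word_map and matrix come from preprocessing and the hypothetical get_weights_matrix_from_word_map):

```python
# Toy stand-ins for the real preprocessing outputs.
word_map = {"the": 1, "cat": 2, "sat": 3}   # token -> index, 1-based
embedding_dim = 300

# One row per index 0..len(word_map); row 0 is reserved for padding.
embedding_matrix = [[0.0] * embedding_dim for _ in range(len(word_map) + 1)]

input_dim = len(word_map) + 1   # what Embedding(...) is given
vocab_size = len(word_map) + 1  # what Dense(vocab_size) should be given

assert len(embedding_matrix) == input_dim
assert vocab_size == input_dim
print("vocabulary sizes are consistent:", vocab_size)
```

If vocab_size in the real code differs from len(word_map)+1, the softmax layer and the embedding are sized for different vocabularies, which would explain an unexpectedly large class count.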

Any help or guidance would be greatly appreciated! I've searched many similar posted issues, but so far I haven't been able to apply those fixes to my code. Thank you!

2 Answers:

Answer 0 (score: 3)

You are exceeding the memory size of your GPU.

You can either:

  • Train/predict with smaller batches
  • Or, if even batch_size=1 is too big, use a model with fewer parameters.

Hint: the length of that tensor (845246) is really big. Is that the correct length?
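To put numbers on that hint: with 845,246 output classes and the layer sizes from the question, the parameter counts (computed the way Keras counts them; plain arithmetic, no TensorFlow needed) come out roughly like this:

```python
vocab = 845246                  # first dimension from the OOM message
emb_dim, lstm_units = 300, 50   # layer sizes from the question

embedding_params = vocab * emb_dim                 # 300 weights per token
# LSTM: 4 gates, each with input, recurrent, and bias weights.
lstm_params = 4 * ((emb_dim + lstm_units) * lstm_units + lstm_units)
dense_params = lstm_units * vocab + vocab          # weights + biases

total = embedding_params + lstm_params + dense_params
print(f"embedding: {embedding_params / 1e6:.1f}M")
print(f"lstm:      {lstm_params / 1e3:.1f}K")
print(f"dense:     {dense_params / 1e6:.1f}M")
print(f"total:     {total / 1e6:.1f}M params "
      f"(~{total * 4 / 1e9:.1f} GB as float32 weights alone)")
```

Almost all of the model sits in the embedding and the softmax layer; both scale linearly with the vocabulary, so if 845,246 is not the intended vocabulary size, fixing it shrinks the model dramatically.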

Answer 1 (score: 0)

I ran into the same problem while using a Google Colab GPU. The batch size was 64 and this error occurred; after I reduced the batch size to 32, it worked fine.
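Why halving the batch size can be enough: with 845,246 softmax classes, the model's output tensor alone grows linearly with the batch size. A rough sketch, assuming float32 logits:

```python
# Memory taken by just the softmax output tensor at different batch sizes.
classes = 845246
bytes_per_float32 = 4

for batch_size in (64, 32):
    logits_bytes = batch_size * classes * bytes_per_float32
    print(f"batch {batch_size}: output tensor alone ~{logits_bytes / 1e6:.0f} MB")
```

On a GPU that is already nearly full from the weights and optimizer state, shaving ~100 MB of activations per tensor like this can make the difference between fitting and OOM.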