flatten_parameters()有什么作用?

时间:2018-11-09 18:38:38

标签: pytorch

我在RNN的正向函数中看到了许多使用flatten_parameters的Pytorch示例

self.rnn.flatten_parameters()

我看到了这个RNNBase,上面写着

  

重置参数数据指针,以便它们可以使用更快的代码路径

那是什么意思?

1 个答案:

答案 0 :(得分:1)

这可能不是您问题的完整答案。但是,如果您看一下flatten_parameters的源代码,您会注意到它在

中调用了_cudnn_rnn_flatten_weight
...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...

是完成任务的功能。您会发现它实际上是在以下位置将模型的权重复制到vector<Tensor>(检查params_arr声明)中:

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

权重复制到

  // Copy weights
  _copyParams(weight, params);

还请注意,它们通过执行就地操作{{,用新的Reset指针来更新(或在文档中明确说是weights,如其在文档中明确指出的那样) params

中的1}}(.set_是就地操作的符号)
_

根据n2798 (draft of C++0x)

©ISO / IECN3092

  

23.3.6类模板向量

     

向量是支持随机访问迭代器的序列容器。另外,它在末尾支持(摊销)恒定时间插入和擦除操作。在中间插入和擦除需要线性时间。存储管理是自动处理的,尽管可以提供一些提示以提高效率。 向量的元素是连续存储的,这意味着如果orig_param.set_(new_param.view_as(orig_param));是向量 // Update the storage for (size_t i = 0; i < weight.size(0); i++) { for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin(); orig_param_it != weight[i].end() && new_param_it != params[i].end(); orig_param_it++, new_param_it++) { auto orig_param = *orig_param_it, new_param = *new_param_it; orig_param.set_(new_param.view_as(orig_param)); } } ,其中v是bool以外的其他类型,则它遵循{所有<T, Allocator>的{​​1}}。


在某些situations

  

UserWarning:RNN模块权重不是单个连续内存块的一部分。这意味着它们需要在每次调用时进行压缩,从而可能极大地增加内存使用量。要再次压缩权重,请调用T

他们在代码警告中明确建议人们拥有连续的内存块。