pytorch的CUDA:CUDA C ++流和状态

时间:2019-04-30 10:35:54

标签: c++ cuda pytorch

我正在尝试遵循this tutorial并使用CUDA后端进行简单的c ++扩展。
我的CPU实现似乎工作正常。

我在查找示例和文档时遇到了麻烦(似乎事情在不断变化)。

具体地说,

  1. 我看到pytorch cuda函数得到THCState *state参数-该参数来自何处?我如何也可以为我的功能获得state
    例如,在tensor.cat的cuda实现中:

    void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension)
    

    但是,当从python调用tensor.cat()时,不提供任何state参数,pytorch会在“幕后”提供它。 pytorch如何提供此信息以及如何获取?

  2. state然后转换为cudaStream_t stream = THCState_getCurrentStream(state);
    由于某些原因,THCState_getCurrentStream不再定义了吗?如何从stream中获取state

我也尝试在pytorch论坛上提问-到目前为止没有任何效果。

1 个答案:

答案 0 :(得分:2)

已弃用(无文档!) 看这里: https://github.com/pytorch/pytorch/pull/14500

简而言之:使用at::cuda::getCurrentCUDAStream()