我正在尝试遵循this tutorial并使用CUDA后端进行简单的c ++扩展。
我的CPU实现似乎工作正常。
我在查找示例和文档时遇到了麻烦(似乎事情在不断变化)。
具体地说,
我看到pytorch cuda函数得到THCState *state
参数-该参数来自何处?我如何也可以为我的功能获得state
?
例如,在tensor.cat
的cuda实现中:
void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension)
但是,当从python调用tensor.cat()
时,不提供任何state
参数,pytorch会在“幕后”提供它。 pytorch如何提供此信息以及如何获取?
state
然后转换为cudaStream_t stream = THCState_getCurrentStream(state);
由于某些原因,THCState_getCurrentStream
不再定义了吗?如何从stream
中获取state
?
我也尝试在pytorch论坛上提问-到目前为止没有任何效果。
答案 0 :(得分:2)
已弃用(无文档!) 看这里: https://github.com/pytorch/pytorch/pull/14500
简而言之:使用at::cuda::getCurrentCUDAStream()