与C.gather

时间:2017-08-14 01:26:29

标签: cntk

尝试测试CNTK时出错。

我正在尝试使用input_variable作为索引来切片参数。使用C.gather进行切片会导致backprop进程出现内存错误。

在所有cntk2环境中都会出现错误,例如CPU,GPU,Docker,本地安装。

错误消息和Callstack

  

RuntimeError:CUBLAS失败11:CUBLAS_STATUS_MAPPING_ERROR; GPU = 0;   hostname = ....; expr = cublasGetMatrix((int)numRows,(int)   numCols,sizeof(ElemType),Data(),(int)GetNumRows(),dst,(int)   colStride)

     

[CALL STACK]         Microsoft :: MSR :: CNTK :: CudaTimer ::停止        - Microsoft :: MSR :: CNTK :: Matrix :: CopySection        - Microsoft :: MSR :: CNTK :: Matrix :: AssignValuesOf        - CNTK :: NDArrayView :: CopyFrom        - CNTK :: NDArrayView :: NDArrayView
       - CNTK :: TrainingParameterSchedule :: Serialize        - CNTK :: DictionaryValue :: Save        - CNTK :: Trainer :: SummarizeTrainingProgress        - PyInit__cntk_py        - PyCFunction_Call        - PyEval_GetFuncDesc        - PyEval_EvalFrameEx        - PyEval_GetFuncDesc(x2)        - PyEval_EvalFrameEx(x2)

代码

x = input_val[:-2]
p1 = input_val[-2]
p2 = input_val[-1]

activator = relu

W1 = C.Parameter((slices,input_dim,hidden_layers_dim), init=C.glorot_normal(), name='W1')
b1 = C.Parameter((slices,hidden_layers_dim), init=0, name='b1')
W2 = C.Parameter((slices,hidden_layers_dim,hidden_layers_dim), init=C.glorot_normal(), name='W2')
b2 = C.Parameter((slices,hidden_layers_dim), init=0, name='b2')
W3 = C.Parameter((slices,hidden_layers_dim,output_dim), init=C.glorot_normal(), name='W3')
b3 = C.Parameter((slices,output_dim), init=0, name='b3')

W11 = C.gather(W1, p1)
b11 = C.gather(b1, p1)
W1x = C.reshape(W11, (input_dim,hidden_layers_dim))
b1x = C.reshape(b11, (hidden_layers_dim,))

W21 = C.gather(W2, p1)
b21 = C.gather(b2, p1)
W2x = C.reshape(W21, (hidden_layers_dim,hidden_layers_dim))
b2x = C.reshape(b21, (hidden_layers_dim,))

W31 = C.gather(W3, p1)
b31 = C.gather(b3, p1)
W3x = C.reshape(W31, (hidden_layers_dim,output_dim))
b3x = C.reshape(b31, (output_dim,))

x = activator(C.times(x, W1x) + b1x)
x = activator(C.times(x, W2x) + b2x)
x = C.times(x, W3x) + b3x

1 个答案:

答案 0 :(得分:0)

我无法用最新的大师重现这一点。这很可能是在CNTK 2.1发布后立即修复的错误。下一个版本(2.2)将在2017年9月15日左右。如果问题仍然存在,在升级到2.2后,打开github issue可能是解决此问题的正确方法。