通过conv2d反向传播时出现Tensorflow错误

时间:2017-07-04 09:11:16

标签: python tensorflow

我面临一个非常奇怪的问题。我有一个网络,归结为以下“中央”代码:

# COSINE
proj = tf.multiply( proj, cosine_w, name = 'cosine-weighting' )

# PARKER
proj = tf.multiply( proj, parker_w, name = 'parker-weighting' )

# RAMLAK
s = config.proj_shape
proj = tf.reshape( proj, [ s.N, 1, s.H, s.W ] )
proj = tf.nn.conv2d(
       input = proj,
       filter = kernel,
       strides = [ 1, 1, 1, 1 ],
       padding = 'SAME',
       data_format = 'NCHW',
       name = 'ramlak-filter'
)
proj = tf.reshape( proj, config.proj_shape.toNCHW() )

# BACKPROJECTION
volume = backproject(
       projections = proj,
       # other arguments, which are attrs in the user defined op
)

我在proj中得到了一些投影数据,这是一个N x H x W张量(其中N是投影数)。然后将该数据分两个阶段加权,然后用一维滤波器内核进行过滤。请注意,我不希望不同的投影图像(N维度)具有不同的权重。因此,我将proj重塑为通道维度中的大小为1,并将投影图像“解释”为批处理中的不同图像。 backproject函数是一个自定义张量流操作,用c ++ / cuda实现,带有一个注册的渐变。

一切都适用于前锋传球。但是,如果我尝试计算滤波器内核的梯度,例如由

tf.gradients( volume, kernel, volume )

如果收到以下错误:

NotFoundError (see above for traceback): No algorithm without scratch worked!
 [[Node: gradients/LAReconstructor_1/LAReconstructor/ramlak-filter_grad/Conv2DBackpropFilter = Conv2DBackpropFilter[T=DT_FLOAT, _class=["loc:@LAReconstructor_1/LAReconstructor/ramlak-filter"], data_format="NCHW", padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/gpu:0"](LAReconstructor_1/LAReconstructor/Reshape, gradients/LAReconstructor_1/LAReconstructor/ramlak-filter_grad/Shape_1, gradients/LAReconstructor_1/LAReconstructor/Reshape_1_grad/Reshape)]]

我试图提供一些重现错误的最小例子,但我无法在这么小的例子中重现它。我已经检查了渐变w.r.t到proj,它符合我的期望。

有没有人知道这里会出现什么问题?

编辑:

我刚刚发现了一个产生相同错误的最小例子:

import tensorflow as tf

proj = tf.Variable( tf.random_normal([720,1,400,600], stddev = 2) )
kernel = tf.Variable( tf.random_normal([1, 401, 1, 1], stddev = .5), trainable = True )
proj = tf.nn.conv2d(
    input = proj,
    filter = kernel,
    strides = [ 1, 1, 1, 1 ],
    padding = 'SAME',
    data_format = 'NCHW',
    name = 'ramlak-filter'
)
grad = tf.gradients( proj, kernel, proj )

with tf.Session() as sess:
    sess.run( tf.global_variables_initializer() )
    print( sess.run( grad ) )

似乎与proj的大小有关。如果我将其更改为[100, 1, 400, 600],则错误消失。但实际上我需要这么大的批量。有什么想法吗?

1 个答案:

答案 0 :(得分:0)

与此同时,这已被确认为一个错误,并提升到tensorflow:https://github.com/tensorflow/tensorflow/issues/11327