我面临一个非常奇怪的问题。我有一个网络,归结为以下“中央”代码:
# COSINE
proj = tf.multiply( proj, cosine_w, name = 'cosine-weighting' )
# PARKER
proj = tf.multiply( proj, parker_w, name = 'parker-weighting' )
# RAMLAK
s = config.proj_shape
proj = tf.reshape( proj, [ s.N, 1, s.H, s.W ] )
proj = tf.nn.conv2d(
input = proj,
filter = kernel,
strides = [ 1, 1, 1, 1 ],
padding = 'SAME',
data_format = 'NCHW',
name = 'ramlak-filter'
)
proj = tf.reshape( proj, config.proj_shape.toNCHW() )
# BACKPROJECTION
volume = backproject(
projections = proj,
# other arguments, which are attrs in the user defined op
)
我在proj
中得到了一些投影数据,这是一个N x H x W
张量(其中N
是投影数)。然后将该数据分两个阶段加权,然后用一维滤波器内核进行过滤。请注意,我不希望不同的投影图像(N
维度)具有不同的权重。因此,我将proj
重塑为通道维度中的大小为1,并将投影图像“解释”为批处理中的不同图像。 backproject函数是一个自定义张量流操作,用c ++ / cuda实现,带有一个注册的渐变。
一切都适用于前锋传球。但是,如果我尝试计算滤波器内核的梯度,例如由
tf.gradients( volume, kernel, volume )
如果收到以下错误:
NotFoundError (see above for traceback): No algorithm without scratch worked!
[[Node: gradients/LAReconstructor_1/LAReconstructor/ramlak-filter_grad/Conv2DBackpropFilter = Conv2DBackpropFilter[T=DT_FLOAT, _class=["loc:@LAReconstructor_1/LAReconstructor/ramlak-filter"], data_format="NCHW", padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/gpu:0"](LAReconstructor_1/LAReconstructor/Reshape, gradients/LAReconstructor_1/LAReconstructor/ramlak-filter_grad/Shape_1, gradients/LAReconstructor_1/LAReconstructor/Reshape_1_grad/Reshape)]]
我试图提供一些重现错误的最小例子,但我无法在这么小的例子中重现它。我已经检查了渐变w.r.t到proj
,它符合我的期望。
有没有人知道这里会出现什么问题?
编辑:
我刚刚发现了一个产生相同错误的最小例子:
import tensorflow as tf
proj = tf.Variable( tf.random_normal([720,1,400,600], stddev = 2) )
kernel = tf.Variable( tf.random_normal([1, 401, 1, 1], stddev = .5), trainable = True )
proj = tf.nn.conv2d(
input = proj,
filter = kernel,
strides = [ 1, 1, 1, 1 ],
padding = 'SAME',
data_format = 'NCHW',
name = 'ramlak-filter'
)
grad = tf.gradients( proj, kernel, proj )
with tf.Session() as sess:
sess.run( tf.global_variables_initializer() )
print( sess.run( grad ) )
似乎与proj
的大小有关。如果我将其更改为[100, 1, 400, 600]
,则错误消失。但实际上我需要这么大的批量。有什么想法吗?
答案 0 :(得分:0)
与此同时,这已被确认为一个错误,并提升到tensorflow:https://github.com/tensorflow/tensorflow/issues/11327