如何在损失函数中使用tf.contrib.image.rotate?

时间:2019-07-08 15:14:18

标签: tensorflow rotation loss

我在损失函数中使用了tf.conbrib.image.rotate,并且发生了错误:

No gradients provided for any variable, check your graph for ops that do not support gradients,

我的代码是:

import tensorflow as tf

image_tensor = tf.placeholder(dtype=tf.float32, shape=[None,320,320,1])
target_tensor = tf.placeholder(dtype=tf.float32, shape=[None,320,320,1])
s = tf.concat([image_tensor, target_tensor],axis=3)
s = tf.layers.flatten(s)
w = tf.get_variable(initializer=tf.truncated_normal([204800,1], stddev=0.1),name='w')
b = tf.get_variable(initializer=tf.truncated_normal([1], stddev=0.1),name='b')
a = tf.matmul(s,w)+b
diff = tf.contrib.image.rotate(image_tensor, a[:,0], interpolation='BILINEAR') - target_tensor
loss = tf.reduce_sum(tf.square(diff))

optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)

我的张量流是:1.4.0,我的计算机是Win10。

顺便问一下,如何在tensorflow中旋转3D图像? tf.conbrib.image.rotate仅适用于2D图像。

1 个答案:

答案 0 :(得分:0)

tf.contrib.image.rotate无法优化。我找到了对空间变换网络(https://github.com/kevinzakka/spatial-transformer-network)有用的代码。