我想定义一个自定义渐变色,我需要使用一个占位符变量来定义渐变色。简单地将占位符作为变量传递似乎不起作用,我认为这与返回的函数所需的函数签名有关,但我无法解读错误消息。
尝试#1
@tf.custom_gradient
def sdf(placeholder_sdf, placeholder_sdf_gradient, point):
"""
point is a 1x2 numpy array [[x], [y]]
this function is a temporary hack just to test things
"""
resolution = np.array([[1], [1]], dtype=np.float32)
integer_coordinates = tf.cast(tf.divide(point, resolution), dtype=tf.int32)
# blindly assume the point is within our grid
sdf_value = tf.gather_nd(placeholder_sdf, integer_coordinates)
# noinspection PyUnusedLocal
def __sdf_gradient_func(dy):
sdf_gradient = tf.gather_nd(placeholder_sdf_gradient, integer_coordinates)
return sdf_gradient
return sdf_value, __sdf_gradient_func
给我错误
Traceback (most recent call last):
File "./scripts/linear_invertible_constraint_model_runner.py", line 127, in <module>
args.func(args)
File "./scripts/linear_invertible_constraint_model_runner.py", line 18, in train
L=args.L, dt=DT)
File "/home/pmitrano/catkin_ws/src/link_bot/link_bot_notebooks/src/link_bot_notebooks/linear_invertible_constraint_model.py", line 130, in __init__
self.opt = tf.train.AdamOptimizer().minimize(self.loss, global_step=self.global_step)
File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 400, in minimize
grad_loss=grad_loss)
File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 519, in compute_gradients
colocate_gradients_with_ops=colocate_gradients_with_ops)
File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 630, in gradients
gate_gradients, aggregation_method, stop_gradients)
File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 821, in _GradientsHelper
_VerifyGeneratedGradients(in_grads, op)
File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 323, in _VerifyGeneratedGradients
"inputs %d" % (len(grads), op.node_def, len(op.inputs)))
ValueError: Num gradients 2 generated for op name: "IdentityN"
op: "IdentityN"
input: "GatherNd"
input: "sdf"
input: "sdf_gradient"
input: "constraint_1"
attr {
key: "T"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
type: DT_FLOAT
}
}
}
attr {
key: "_gradient_op_type"
value {
s: "CustomGradient-78"
}
}
do not match num inputs 4
我不希望相对于占位符定义渐变,所以我认为也许不能将它们用作函数的参数?如何正确定义呢?