How to create a custom gradient function that uses placeholders?

Posted: 2019-01-17 19:39:36

Tags: tensorflow

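I want to define a custom gradient, and I need a placeholder to define that gradient. Simply passing the placeholders in as arguments doesn't seem to work. I think this has to do with the signature required of the returned gradient function, but I can't decipher the error message.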

Attempt #1

import numpy as np
import tensorflow as tf

@tf.custom_gradient
def sdf(placeholder_sdf, placeholder_sdf_gradient, point):
    """
    point is a 2x1 numpy array [[x], [y]]
    this function is a temporary hack just to test things
    """
    resolution = np.array([[1], [1]], dtype=np.float32)
    integer_coordinates = tf.cast(tf.divide(point, resolution), dtype=tf.int32)
    # blindly assume the point is within our grid
    sdf_value = tf.gather_nd(placeholder_sdf, integer_coordinates)

    # noinspection PyUnusedLocal
    def __sdf_gradient_func(dy):
        sdf_gradient = tf.gather_nd(placeholder_sdf_gradient, integer_coordinates)
        return sdf_gradient

    return sdf_value, __sdf_gradient_func

This gives me the following error:

Traceback (most recent call last):
  File "./scripts/linear_invertible_constraint_model_runner.py", line 127, in <module>
    args.func(args)
  File "./scripts/linear_invertible_constraint_model_runner.py", line 18, in train
    L=args.L, dt=DT)
  File "/home/pmitrano/catkin_ws/src/link_bot/link_bot_notebooks/src/link_bot_notebooks/linear_invertible_constraint_model.py", line 130, in __init__
    self.opt = tf.train.AdamOptimizer().minimize(self.loss, global_step=self.global_step)
  File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 400, in minimize
    grad_loss=grad_loss)
  File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/training/optimizer.py", line 519, in compute_gradients
    colocate_gradients_with_ops=colocate_gradients_with_ops)
  File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 630, in gradients
    gate_gradients, aggregation_method, stop_gradients)
  File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 821, in _GradientsHelper
    _VerifyGeneratedGradients(in_grads, op)
  File "/home/pmitrano/.local/lib/python2.7/site-packages/tensorflow/python/ops/gradients_impl.py", line 323, in _VerifyGeneratedGradients
    "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
ValueError: Num gradients 2 generated for op name: "IdentityN"
op: "IdentityN"
input: "GatherNd"
input: "sdf"
input: "sdf_gradient"
input: "constraint_1"
attr {
  key: "T"
  value {
    list {
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_FLOAT
      type: DT_FLOAT
    }
  }
}
attr {
  key: "_gradient_op_type"
  value {
    s: "CustomGradient-78"
  }
}
 do not match num inputs 4
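Reading the message: in graph mode, the IdentityN op that tf.custom_gradient inserts appears to take the function's output (GatherNd) followed by each positional argument (sdf, sdf_gradient, constraint_1), so it needs 4 gradients. The gradient function returned a single tensor, which together with the slot for the output makes 2. In other words, grad_fn must return one gradient per positional argument, in order, and None is accepted for arguments such as the placeholders that should not receive a gradient. The sketch below shows that shape. It assumes TensorFlow 2.x eager execution, a [2] point instead of the question's [[x], [y]] array, and a hypothetical name sdf_lookup; it also multiplies by dy for the chain rule, which the original gradient function ignores.

```python
import tensorflow as tf


# Sketch assuming TensorFlow 2.x eager execution; sdf_lookup is a
# hypothetical stand-in for the question's sdf function.
@tf.custom_gradient
def sdf_lookup(sdf_grid, sdf_gradient_grid, point):
    """Look up a precomputed signed distance at a float (x, y) point.

    sdf_grid: [H, W] float32 tensor of distances.
    sdf_gradient_grid: [H, W, 2] float32 tensor of precomputed gradients.
    point: [2] float32 tensor; a resolution of 1 unit per cell is assumed.
    """
    resolution = tf.constant([1.0, 1.0])
    cell = tf.cast(tf.floor(point / resolution), tf.int32)
    value = tf.gather_nd(sdf_grid, cell)  # scalar distance

    def grad_fn(dy):
        # One entry per positional input, in the same order as the arguments.
        # None means "no gradient" for the two lookup tables; the gradient
        # for `point` is scaled by the upstream gradient dy (chain rule).
        d_point = dy * tf.gather_nd(sdf_gradient_grid, cell)
        return None, None, d_point

    return value, grad_fn
```

Under these assumptions, the same return convention (None, None, gradient) should carry over to TF 1.x graph code with placeholders as the first two arguments, though that variant is not shown here.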

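I don't want to define gradients with respect to the placeholders, so I wonder whether they simply can't be used as arguments to the function. How do I define this correctly?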
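Placeholders can stay as arguments as long as the gradient function returns None for them. An alternative is to close over the lookup tables so that the decorated function takes only the point, and the gradient function then returns a single gradient. This is a minimal sketch assuming TensorFlow 2.x eager execution; make_sdf_lookup and sdf_lookup are hypothetical names, and the [2] point shape and 1-unit resolution are assumptions.

```python
import tensorflow as tf


# Sketch assuming TensorFlow 2.x eager execution; make_sdf_lookup is a
# hypothetical helper, not part of any TensorFlow API.
def make_sdf_lookup(sdf_grid, sdf_gradient_grid, resolution=1.0):
    """Return a differentiable lookup that closes over the two grids."""

    @tf.custom_gradient
    def sdf_lookup(point):
        # point: [2] float32 tensor (x, y)
        cell = tf.cast(tf.floor(point / resolution), tf.int32)
        value = tf.gather_nd(sdf_grid, cell)

        def grad_fn(dy):
            # Only `point` is an argument, so exactly one gradient is returned.
            return dy * tf.gather_nd(sdf_gradient_grid, cell)

        return value, grad_fn

    return sdf_lookup
```

One caveat: if the captured tensors were tf.Variable objects rather than placeholders or constants, tf.custom_gradient expects the gradient function to accept a `variables` keyword argument and to return gradients for those variables as well.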

0 answers