修改梯度输出数量以进行自定义梯度(tensorflow)

时间:2018-07-20 16:40:02

标签: python tensorflow

类似于此问题(ValueError: Num gradients 1 generated for op name: "mask/Mask"),我也遇到了梯度数量与输入数量不匹配的问题。我只需要第一个输入的梯度(在这种情况下为x),因此需要为第二个输入(在此情况下为y)提供一个"假梯度"。希望有人能指点我如何更改梯度输出的数量(并确保忽略第二个输入的梯度)。

def tensorflow_layer_grad_impl(x, z, dy, name):
    """Numpy-backed gradient of the forward layer, wrapped as a TF op.

    For each element of the batch this evaluates ``op.transpose(dy_i, z)``
    and rescales the result by the ratio of the domain/range weighting
    constants, since tensorflow has no notion of weighted spaces.

    Parameters
    ----------
    x : tf.Tensor
        Batched forward input, shape ``(n,) + domain_shape + (1,)``.
    z : tf.Tensor
        Second input of the forward op (model parameters); its first
        dimension selects the model format.  No gradient is computed
        with respect to ``z``.
    dy : tf.Tensor
        Upstream gradient, shape ``(n,) + range_shape + (1,)``.
    name : str
        Name scope under which the ops are created.

    Returns
    -------
    tf.Tensor
        Gradient with respect to ``x``, shape ``out_shape``.

    NOTE(review): when this result is used inside a function registered
    via ``tf.RegisterGradient``, tensorflow requires one gradient per
    *input* of the forward op.  To ignore the second input, have the
    registered gradient function return ``(this_result, None)`` (or
    ``tf.zeros_like(z)`` instead of ``None``) — that resolves the
    "Num gradients N generated for op ..." mismatch.
    """
    model_format = int(z.shape.dims[0])

    with tf.name_scope(name):
        # Validate the input/output shapes.
        x_shape = x.get_shape()
        dy_shape = dy.get_shape()
        try:
            # Lazy check whether the first (batch) dimension is static.
            n_x = int(x_shape[0])
            fixed_size = True
        except TypeError:
            n_x = x_shape[0]
            fixed_size = False

        in_shape = (n_x,) + space_shape(op.range(model_format)) + (1,)
        out_shape = (n_x,) + space_shape(op.domain(model_format)) + (1,)

        assert x_shape[1:] == space_shape(op.domain(model_format)) + (1,)
        assert dy_shape[1:] == space_shape(op.range(model_format)) + (1,)

        def _impl(x, z, dy):
            """Batched numpy evaluation of the adjoint/transpose."""
            # Validate the shape of the given input.
            if fixed_size:
                x_out_shape = out_shape
                assert x.shape == out_shape
                assert dy.shape == in_shape
            else:
                x_out_shape = (x.shape[0],) + out_shape[1:]
                assert x.shape[1:] == out_shape[1:]
                assert dy.shape[1:] == in_shape[1:]

            # Evaluate the operator on all inputs in the batch.
            # BUG FIX: dtype was taken from ``op.domain(z)``; everywhere
            # else the domain is selected via ``model_format``, and ``z``
            # here is the raw parameter array, not a format index.
            out = np.empty(x_out_shape, op.domain(model_format).dtype)
            out_element = op.domain.element()  # scratch element, reused
            for i in range(x_out_shape[0]):
                dyi = dy[i, ..., 0]
                op.transpose(dyi, z, out=out_element)
                out[i, ..., 0] = np.asarray(out_element)
                # BUG FIX: the original also executed
                # ``x = np.asarray(out_element)`` here, clobbering the
                # input batch array after the first iteration (breaking
                # any later ``x[i, ...]`` access) — removed.  The unused
                # ``xi = x[i, ..., 0]`` was dropped as well.

            # Rescale the domain/range according to the weighting since
            # tensorflow does not have weighted spaces.
            # BUG FIX: the original referenced ``odl_op`` and
            # ``angle_array``, neither of which is defined in this scope
            # (NameError at runtime); use the operator and model format
            # that are in scope instead.
            try:
                dom_weight = op.domain(model_format).weighting.const
            except AttributeError:
                dom_weight = 1.0

            try:
                ran_weight = op.range(model_format).weighting.const
            except AttributeError:
                ran_weight = 1.0

            out *= dom_weight / ran_weight
            return out

        # Wrap the numpy implementation as a stateless tensorflow op.
        # BUG FIX: ``values=[x, angle_array, dy]`` referenced the
        # undefined ``angle_array``; the actual inputs are x, z and dy.
        with ops.name_scope(name + '_pyfunc', values=[x, z, dy]) as name_call:
            output = py_func(_impl,
                             [x, z, dy],
                             [tf.float32],
                             name=name_call,
                             stateful=False)

            output = output[0]
            output.set_shape(out_shape)
            return output

然后我们有相应的py_func:

def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    """``tf.py_func`` wrapper that optionally attaches a custom gradient.

    When ``grad`` is None this is a plain ``tf.py_func`` call.  Otherwise
    ``grad`` is registered under a freshly generated unique name and wired
    in through the graph's ``gradient_override_map``.
    """
    if grad is None:
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

    # The op type whose gradient we override depends on statefulness.
    base_op = 'PyFunc' if stateful else 'PyFuncStateless'

    # Generate a unique registration name to avoid duplicates.
    unique_grad_name = base_op + 'Grad' + str(uuid.uuid4())
    tf.RegisterGradient(unique_grad_name)(grad)

    graph = tf.get_default_graph()
    with graph.gradient_override_map({base_op: unique_grad_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

0 个答案:

没有答案