自定义图层中的Tensorflow自定义渐变

时间:2020-07-02 21:00:09

标签: python tensorflow

我正在使用自定义渐变设置自定义图层。输入是单个2D张量形状(?,2)。输出也是一个单一的二维张量,形状为(?,2)。

我正在努力理解这些对象的行为。我从文档中收集到的是,对于给定的输入,渐变将与输出具有相同的形状,并且我需要为每个输入返回一个渐变列表。我一直假设由于我的输入看起来像(?,2)而我的输出看起来像(?,2),所以grad函数应该返回一个长度为2的列表:[input_1_grad,input_2_grad],其中两个列表项都是具有输出形状(?,2)的张量。

这不起作用,这就是为什么我希望这里的人可以提供帮助。

这是我的错误(似乎在编译时发生):

ValueError:为操作名称生成了Num渐变3: “ custom_layer / IdentityN”操作:“ IdentityN”输入: “ custom_layer_2 / concat”输入:“ custom_layer_1 / concat” attr {键: “ T”值{ 清单{ 类型:DT_FLOAT 类型:DT_FLOAT }}} attr {键:“ _ gradient_op_type”的值{ s:“ CustomGradient-28729”}}与num输入2不匹配

另一个难题是,自定义层的输入本身也是一个自定义层(尽管没有自定义渐变)。如果有帮助,我将提供两层代码。

此外,请注意,如果我不尝试指定自定义渐变,则网络会编译并运行。但是,由于我的功能需要帮助以使其与众不同,因此我需要手动进行干预,因此拥有有效的自定义渐变至关重要。

第一个自定义图层(无自定义渐变):

class custom_layer_1(tensorflow.keras.layers.Layer):
    def __init__(self):
        super(custom_layer_1, self).__init__()
    
    def build(self, input_shape):
        self.term_1 = self.add_weight('term_1', trainable=True)
        self.term_2 = self.add_weight('term_2', trainable=True)
    
    def call(self, x):
        self.term_1 = formula in terms of x
        self.term_2 = another formula in terms of x
        
        return tf.concat([self.term_1, self.term_2], axis=1)

第二个自定义图层(具有自定义渐变):

class custom_layer_2(tensorflow.keras.layers.Layer):
    ### the inputs
    # x is the concatenation of term_1 and term_2
    def __init__(self):
        super(custom_layer_2, self).__init__()
    
    def build(self, input_shape):
        #self.weight_1 = self.add_weight('weight_1', trainable=True)
        #self.weight_2 = self.add_weight('weight_2', trainable=True)
    
    def call(self, x):
        return custom_function(x)

自定义功能:

@tf.custom_gradient
def custom_function(x):
    ### the inputs
    # x is a concatenation of term_1 and term_2
    
    weight_1 = function in terms of x
    weight_2 = another function in terms of x
    
    ### the gradient
    def grad(dy):
        # assuming dy has the output shape of (?, 2). could be wrong.
        d_weight_1 = K.reshape(dy[:, 0], shape=(K.shape(x)[0], 1))
        d_weight_1 = K.reshape(dy[:, 1], shape=(K.shape(x)[0], 1))
        
        term_1 = K.reshape(x[:, 0], shape=(K.shape(x)[0], 1))
        term_2 = K.reshape(x[:, 1], shape=(K.shape(x)[0], 1))
        
        d_weight_1_d_term_1 = tf.where(K.equal(term_1, K.zeros_like(term_1)), K.zeros_like(term_1), -term_2 / term_1) * d_weight_1
        d_weight_1_d_term_2 = tf.where(K.equal(term_1, K.zeros_like(term_1)), K.zeros_like(term_1), 1 / term_1) * d_weight_1
        
        d_weight_2_d_term_1 = tf.where(K.equal(term_2, K.zeros_like(term_2)), K.zeros_like(term_1), 2 * term_1 / term_2) * d_weight_2
        d_weight_2_d_term_2 = tf.where(K.equal(term_2, K.zeros_like(term_2)), K.zeros_like(term_1), -K.square(term_1 / term_2)) * d_weight_2
        
        return tf.concat([d_weight_1_d_term_1, d_weight_1_d_term_2], axis=1), tf.concat([d_weight_2_d_term_1, d_weight_2_d_term_2], axis=1)
  
  return tf.concat([weight_1, weight_2], axis=1), grad

任何帮助将不胜感激!

0 个答案:

没有答案