I have implemented a TF2 Keras layer, but during training I get the following error:
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
  my_constant = tf.constant(1.)
  with tf.init_scope():
    added = my_constant * 2
The graph tensor has name: ada_cos_layer_1/truediv:0
I have seen some similar posts, but their problem was a Lambda layer, which I am not using. I think that in my case it is related to assigning to an attribute (self.s) that is not a tf.Variable. However, I have already tried making it one, and also tried using add_weight, and neither helped.
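For reference, this is roughly what I tried (a sketch from memory, the exact initializer may have differed); both variants still raise the same error:

    # Variant 1: a plain non-trainable tf.Variable created in __init__
    self.s = tf.Variable(math.sqrt(2) * math.log(n_classes - 1),
                         trainable=False, dtype=tf.float32)

    # Variant 2: a scalar weight created in build() via add_weight
    self.s = self.add_weight(
        name='s', shape=(),
        initializer=tf.constant_initializer(math.sqrt(2) * math.log(self.n_classes - 1)),
        trainable=False)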
My layer is as follows:
import math

import tensorflow as tf


class AdaCos(tf.keras.layers.Layer):
    def __init__(self, n_classes, margin=None, logit_scale=None, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes
        # initial scale from the AdaCos paper: sqrt(2) * log(C - 1)
        self.s = math.sqrt(2) * math.log(n_classes - 1)

    def build(self, input_shape):
        super().build(input_shape[0])
        self.w = self.add_weight(name='weights',
                                 shape=(input_shape[0][-1], self.n_classes),
                                 initializer='glorot_uniform',
                                 trainable=True)

    @staticmethod
    def get_median(v):
        v = tf.reshape(v, [-1])
        mid = v.get_shape()[0] // 2 + 1
        return tf.nn.top_k(v, mid).values[-1]

    def call(self, inputs):
        x, y = inputs
        # normalize features
        x = tf.nn.l2_normalize(x, axis=1, name='normed_embd')
        # normalize weights
        w = tf.nn.l2_normalize(self.w, axis=0, name='normed_weights')
        # dot product of unit vectors = cosine of the angle
        logits = tf.matmul(x, w, name='cos_t')
        # clip logits to prevent zero division in the backward pass
        theta = tf.acos(tf.clip_by_value(logits, -1.0 + 1e-5, 1.0 - 1e-5))
        # average of exp(s * cos(theta)) over the non-target classes
        B_avg = tf.where(tf.expand_dims(y, 1) < 1,
                         tf.exp(self.s * logits),
                         tf.zeros_like(logits))
        B_avg = tf.reduce_mean(tf.reduce_sum(B_avg, axis=1), name='B_avg')
        # angle between each sample and its target class
        theta_class = tf.gather_nd(theta,
                                   tf.expand_dims(tf.cast(y, tf.int32), 1),
                                   batch_dims=1, name='theta_class')
        theta_med = self.get_median(theta_class)
        with tf.control_dependencies([theta_med, B_avg]):
            # dynamically re-estimate the scale; this assignment is, I suspect,
            # where the graph tensor leaks out
            self.s = tf.math.log(B_avg) / tf.cos(tf.minimum(math.pi / 4, theta_med))
            out = tf.multiply(logits, self.s, name='arcface_logits')
        return out

    def compute_output_shape(self, input_shape):
        return (None, self.n_classes)
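From the posts I have read, my current understanding is that rebinding the attribute (self.s = <tensor>) replaces its value with a graph tensor from one trace, which then leaks into the next one. A minimal sketch of the variant I am now considering, assuming self.s is kept as a non-trainable tf.Variable created in __init__ (not verified yet):

    with tf.control_dependencies([theta_med, B_avg]):
        new_s = tf.math.log(B_avg) / tf.cos(tf.minimum(math.pi / 4, theta_med))
        # update the existing variable in place instead of rebinding the attribute
        self.s.assign(new_s)
        out = tf.multiply(logits, self.s, name='arcface_logits')

Is this the right direction, or is something else leaking the graph tensor?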