我正在为tf.keras创建一个自定义层,它可以在推理/训练中正常工作,但是初始加载时间非常慢。它显然来自我用于在call()中拆分数据的嵌套for循环,但是我不确定如何进行矢量化或只是加快此过程。任何建议都是好的!谢谢!
我尝试使用tf.dynamic_partition,但我不确定仅阅读tensorflow网站上的文档即可完全理解该方法。
class AttentionLayer(tf.keras.layers.Layer):
def __init__(self, output_units):
super(AttentionLayer, self).__init__()
self.output_units = output_units
def build(self, input_shape):
self.len = input_shape[1]
self.cells = 100
if len(input_shape) == 3:
self.c = input_shape[2]
else:
self.c = 1
self.WQs = []
self.WKs = []
self.WA = self.add_variable("WA", [int(self.len - 1), 1],initializer=tf.glorot_uniform_initializer)
for idx in range(self.cells):
WQ = self.add_variable("WQ" + str(idx), [self.c , self.output_units]
,initializer=tf.glorot_uniform_initializer)
WK = self.add_variable("WK" + str(idx), [self.c , self.output_units]
,initializer=tf.glorot_uniform_initializer)
self.WQs.append(WQ)
self.WKs.append(WK)
def call(self, input):
attention = []
array = tf.reshape(input, [-1, self.len, self.c])
batch_size = tf.shape(input)[0]
for idx in range(self.cells):
print(str(idx), end="\r")
Q = tf.reshape(array[:, idx], [-1, 1, self.c])
#print("Q: ", Q)
context_list = []
for cdx in range(self.len):
if idx != cdx:
context_list.append(tf.reshape(array[:, cdx], [-1, 1, self.c]))
K = tf.concat(context_list, 1)
#print("K: ", K)
WQ_expand = tf.expand_dims(self.WQs[idx], axis=0)
WK_expand = tf.expand_dims(self.WKs[idx], axis=0)
WQ_tile = tf.tile(WQ_expand, [batch_size, 1, 1])
WK_tile = tf.tile(WK_expand, [batch_size, 1, 1])
Q = tf.matmul(Q, WQ_tile)
K = tf.matmul(K, WK_tile)
a = tf.nn.sigmoid(tf.matmul(Q,tf.reshape(K, [-1, self.output_units, int(self.len - 1)]))/27.5)
#print("a: ", a)
attention.append(a)
A = tf.concat(attention,1)
#print("A: ", A)
WA_expand = tf.expand_dims(self.WA, axis=0)
WA_tile = tf.tile(WA_expand, [batch_size, 1, 1])
Z = tf.reshape(tf.nn.sigmoid(tf.matmul(A,WA_tile)), [-1, self.cells, 1])
#print("Z: ", Z)
return Z
该图层采用3D输入(批处理大小,宽度*高度,通道)..打印语句应显示正在初始化的图层。