How can I split a tensor efficiently so that the initial graph load time isn't so long?

Posted: 2019-06-21 21:06:31

Tags: python tensorflow keras

I'm writing a custom layer for tf.keras. It works fine for inference and training, but the initial load time is extremely slow. The slowness clearly comes from the nested for loops I use to split the data in call(), but I'm not sure how to vectorize them or otherwise speed this up. Any suggestions are welcome! Thanks!

I tried using tf.dynamic_partition, but I'm not sure I fully understand the method just from reading the documentation on the TensorFlow website.
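For reference, here is a minimal sketch of what tf.dynamic_partition actually does, on toy data (TF 2.x eager mode assumed; the values are illustrative only):

```python
import tensorflow as tf

# Toy data: 6 rows of 2 features each.
data = tf.constant([[1., 2.], [3., 4.], [5., 6.],
                    [7., 8.], [9., 10.], [11., 12.]])

# One partition id per row: rows tagged 0 go to the first output
# tensor, rows tagged 1 go to the second.
partitions = tf.constant([0, 1, 0, 1, 0, 1])

# Returns a Python list of `num_partitions` tensors.
parts = tf.dynamic_partition(data, partitions, num_partitions=2)
# parts[0] holds rows 0, 2, 4; parts[1] holds rows 1, 3, 5.
```

Note that the split is driven by a per-element id tensor, not by position, so it doesn't map directly onto the "every position except idx" split the layer below needs.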

class AttentionLayer(tf.keras.layers.Layer):
    def __init__(self, output_units):
        super(AttentionLayer, self).__init__()
        self.output_units = output_units

    def build(self, input_shape):
        self.len = input_shape[1]
        self.cells = 100

        # Channel count: the last dimension for 3-D input, else 1.
        if len(input_shape) == 3:
            self.c = input_shape[2]
        else:
            self.c = 1

        self.WQs = []
        self.WKs = []
        self.WA = self.add_variable("WA", [int(self.len - 1), 1],
                                    initializer=tf.glorot_uniform_initializer)

        # One query/key weight matrix per attention cell.
        for idx in range(self.cells):
            WQ = self.add_variable("WQ" + str(idx), [self.c, self.output_units],
                                   initializer=tf.glorot_uniform_initializer)
            WK = self.add_variable("WK" + str(idx), [self.c, self.output_units],
                                   initializer=tf.glorot_uniform_initializer)
            self.WQs.append(WQ)
            self.WKs.append(WK)

    def call(self, input):
        attention = []
        array = tf.reshape(input, [-1, self.len, self.c])
        batch_size = tf.shape(input)[0]

        for idx in range(self.cells):
            print(str(idx), end="\r")
            # Query: the feature vector at position idx.
            Q = tf.reshape(array[:, idx], [-1, 1, self.c])

            # Keys: every other position. This nested loop is the slow part.
            context_list = []
            for cdx in range(self.len):
                if idx != cdx:
                    context_list.append(tf.reshape(array[:, cdx], [-1, 1, self.c]))
            K = tf.concat(context_list, 1)

            # Tile the per-cell weights across the batch and project Q and K.
            WQ_tile = tf.tile(tf.expand_dims(self.WQs[idx], axis=0), [batch_size, 1, 1])
            WK_tile = tf.tile(tf.expand_dims(self.WKs[idx], axis=0), [batch_size, 1, 1])
            Q = tf.matmul(Q, WQ_tile)
            K = tf.matmul(K, WK_tile)

            # Scaled sigmoid attention of Q against all other positions.
            a = tf.nn.sigmoid(
                tf.matmul(Q, tf.reshape(K, [-1, self.output_units, int(self.len - 1)])) / 27.5)
            attention.append(a)

        A = tf.concat(attention, 1)
        WA_tile = tf.tile(tf.expand_dims(self.WA, axis=0), [batch_size, 1, 1])
        Z = tf.reshape(tf.nn.sigmoid(tf.matmul(A, WA_tile)), [-1, self.cells, 1])
        return Z

The layer takes a 3-D input of shape (batch_size, width * height, channels). The print statement is just there to show which cell is currently being initialized.
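As a rough sketch of one way the inner splitting loop might be vectorized (not the layer's actual implementation; `L`, `c`, and the index matrix below are toy illustrations): precompute, for each position i, the indices of every other position j != i, then pull all of the "context" slices out with a single tf.gather instead of looping over cdx.

```python
import numpy as np
import tensorflow as tf

L, c = 5, 3  # toy sequence length and channel count

# idx[i] lists every position except i, so idx has shape (L, L-1).
idx = np.array([[j for j in range(L) if j != i] for i in range(L)])

x = tf.reshape(tf.range(2 * L * c, dtype=tf.float32), [2, L, c])  # (batch, L, c)

# One gather builds the context for every position at once:
# contexts[b, i] is x[b] with row i removed, shape (batch, L, L-1, c).
contexts = tf.gather(x, idx, axis=1)
```

This removes the Python loop over cdx entirely; the outer loop over idx could likely be folded into one batched matmul as well by stacking the per-cell WQ/WK matrices into single (cells, c, output_units) variables, though that changes the variable layout.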

0 Answers:

No answers yet.