How can I speed up a custom convolution?

Time: 2019-08-22 14:54:56

Tags: python tensorflow optimization convolution

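I am trying to implement a custom convolution function that singles out every individual multiplication and randomly applies errors to its multiplicands. My current approach is to replicate each element of the filter and of the image once for every multiplication it takes part in, apply the errors to each copy, multiply elementwise, and then sum.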
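The approach described above can be sketched in NumPy as a small reference version. This is only an illustration under stated assumptions: `flip_bits` and `noisy_convolve` are hypothetical names, and the error model (every bit of every int8 multiplicand flips independently with probability `error_rate`) is a guess at what the post's `bit_errors` helper does, since that helper is not shown.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def flip_bits(values, error_rate, rng):
    """Hypothetical helper: flip each bit of an int8 array independently
    with probability error_rate (assumed error model)."""
    bits = values.view(np.uint8)
    mask = np.zeros(bits.shape, dtype=np.uint8)
    for b in range(8):
        flips = rng.random(bits.shape) < error_rate
        mask |= flips.astype(np.uint8) << b
    return (bits ^ mask).view(np.int8)


def noisy_convolve(kernel, image, error_rate, rng):
    """kernel: int8 (kh, kw, c); image: int8 (batch, rows, cols, c).
    Returns float32 (batch, out_rows, out_cols)."""
    kh, kw, _ = kernel.shape
    # One copy of each image element per multiplication it takes part in:
    # (batch, out_rows, out_cols, c, kh, kw), then move c to the end.
    windows = sliding_window_view(image, (kh, kw), axis=(1, 2))
    windows = np.moveaxis(windows, 3, -1)
    # One copy of the kernel per output position (a view, no data copied yet).
    kernels = np.broadcast_to(kernel, windows.shape)
    # Independent errors for every multiplicand.
    windows = flip_bits(windows, error_rate, rng)
    kernels = flip_bits(kernels, error_rate, rng)
    # Elementwise multiply, then sum over each window.
    products = windows.astype(np.float32) * kernels.astype(np.float32)
    return products.sum(axis=(3, 4, 5))


rng = np.random.default_rng(0)
image = rng.integers(-128, 128, size=(2, 5, 6, 3)).astype(np.int8)
kernel = rng.integers(-128, 128, size=(3, 3, 3)).astype(np.int8)
clean = noisy_convolve(kernel, image, 0.0, rng)   # no errors injected
noisy = noisy_convolve(kernel, image, 0.01, rng)  # about 1% of bits flipped
```

With `error_rate` set to 0 the result should match an ordinary valid convolution (strictly, a cross-correlation), which makes the sketch easy to check before errors are switched on.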

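However, for large inputs (such as those in VGG16), this function becomes very slow, especially when I have to apply it to many filters at once. Even without the part that randomly generates errors for the multiplicands, the function is orders of magnitude slower than the implementation in TensorFlow.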

The code is attached below. Any help speeding it up would be appreciated.

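import tensorflow as tf
from tensorflow.keras import backend as K  # assumed import; the post does not show how K was imported

def bit_error_int_convolve(kernel,input,error_rate):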
    """
    2D Convolution for 8 bit integer image and kernel with bit flip error rate of error_rate
    :param kernel: 8 bit integer tensor of shape (filter size (row), filter size (col), channel size)
    :param input: 8 bit integer tensor of shape (batch size, image size (row), image size (col), channel size)
    :param error_rate: Rate of bit flip errors
    :return: 32 bit float tensor of shape (batch size, output image size (row), output image size (col))
    """

    kernel_shape = tf.shape(kernel)
    input_shape = tf.shape(input)

    # Copies the kernel once per output position, giving a rank 6 tensor of shape
    # (batch size, output image size (row), output image size (col), filter size (row), filter size (col), channel number)
    mK = tf.reshape(kernel,(kernel_shape[0]*kernel_shape[1]*kernel_shape[2],))
    mK = tf.tile(mK, [input_shape[0]*(input_shape[1]-kernel_shape[0]+1)*(input_shape[2]-kernel_shape[1]+1)])
    mK = tf.reshape(mK,(input_shape[0],input_shape[1]-kernel_shape[0]+1,input_shape[2]-kernel_shape[1]+1,kernel_shape[0],kernel_shape[1],kernel_shape[2]))

    # Reshapes and copies elements of the image into a rank 6 tensor of shape
    # (batch size, output image size (row), output image size (col), filter size (row), filter size (col), channel number)
    rep = input[:,:kernel_shape[0],:kernel_shape[1],:]

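    # Patch k+1 sits at output row i, column j (row-major order); patch 0 was taken above
    for k in range((tf.shape(mK)[1])*(tf.shape(mK)[2])-1):
        j = (k+1)%(tf.shape(mK)[2])
        # Row index: divide by the number of output columns
        i = (k-j+1)//(tf.shape(mK)[2])
        thing = input[:,i:i+kernel_shape[0],j:j+kernel_shape[1],:]
        rep = tf.concat([rep,thing],0)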

    rep = tf.reshape(rep, (tf.shape(mK)[1], tf.shape(mK)[2], tf.shape(mK)[0], tf.shape(mK)[3], tf.shape(mK)[4], tf.shape(mK)[5]))
    rep = tf.transpose(rep, perm=[2, 0, 1, 3, 4, 5])

    # Applies bit flip errors to multiplicands with rate error_rate
    # Slow even if left out
    # if error_rate!=0:
    #     rep = bit_errors(rep,error_rate)
    #     mK = bit_errors(mK,error_rate)

    # Elementwise multiplication followed by summation to find output image
    rep = K.cast(rep,dtype=tf.float32)
    mK = K.cast(mK,dtype=tf.float32)
    output = rep*mK
    output = K.sum(output, axis=(3, 4, 5))

    return output
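One possible way to avoid the Python loop of slices and concats is to gather all windows with a single call to `tf.image.extract_patches` and to let broadcasting stand in for the tiled kernel. The following is a minimal sketch, assuming TensorFlow 2 in eager mode and statically known kernel dimensions; `patch_convolve` is a hypothetical name, and the bit-flip step is left out, with a comment marking where it would go.

```python
import numpy as np
import tensorflow as tf


def patch_convolve(kernel, image):
    """Valid 2D convolution (cross-correlation) via explicit patches.

    kernel: (kh, kw, c); image: (batch, rows, cols, c).
    Returns float32 (batch, out_rows, out_cols).
    """
    kh, kw, _ = kernel.shape  # assumes static (known) kernel dimensions
    image = tf.cast(image, tf.float32)  # int8 values are exact in float32
    # A single op gathers every kh x kw x c window.
    # Result shape: (batch, out_rows, out_cols, kh*kw*c).
    patches = tf.image.extract_patches(
        images=image,
        sizes=[1, kh, kw, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # Flattened in (row, col, channel) order, matching the patch layout.
    flat_kernel = tf.reshape(tf.cast(kernel, tf.float32), [-1])
    # Per-multiplicand errors would be applied here, to `patches` and to a
    # copy of `flat_kernel` broadcast to the shape of `patches`.
    # Without errors, broadcasting avoids materialising the tiled kernel.
    return tf.reduce_sum(patches * flat_kernel, axis=-1)


# Usage: compare against TensorFlow's built-in convolution on small random data.
rng = np.random.default_rng(0)
image = tf.constant(rng.integers(-128, 128, size=(2, 6, 5, 3)).astype(np.int8))
kernel = tf.constant(rng.integers(-128, 128, size=(3, 2, 3)).astype(np.int8))
out = patch_convolve(kernel, image)
reference = tf.nn.conv2d(
    tf.cast(image, tf.float32),
    tf.cast(kernel, tf.float32)[..., tf.newaxis],  # (kh, kw, c, 1)
    strides=1,
    padding="VALID",
)[..., 0]
```

The patch tensor is still kh*kw times larger than the input, so memory may remain the limiting factor for VGG16-sized layers; processing the batch or the output rows in chunks is one way to bound it.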

0 answers:

No answers