如何训练权重限制为特定值的神经网络?

时间:2019-06-16 22:53:38

标签: tensorflow

我正在尝试训练仅具有某些值的权重的网络。但是,我这样做的方式需要很长时间,例如对于MNIST上的3层完全连接网络,每个纪元5h。有更快的方法吗?

我正在使用tf.keras构建我的网络。我添加了一个自定义tf.constraint,它在更新权重时对可能的权重值列表进行二进制搜索。我从here找到了适合我的应用程序的二进制搜索代码。为了将二进制搜索功能应用于所有参数,我使用“ tf.map_fn”。

这是约束类:

from tensorflow.python.keras.constraints import Constraint
import tensorflow as tf

# binary search function
def find(weights, query, shape):
    vals = tf.map_fn(lambda x: weights[tf.argmin(tf.cast(x >= weights, dtype=tf.int32)[1:] - tf.cast(x >= weights, dtype=tf.int32)[:-1])], tf.reshape(query,[-1]))
    return tf.reshape(vals, shape)

class WeightQuantizeClip(Constraint):
    # weights parameter holds the possible weight values
    def __init__(self, weights = []):

        self.weights = tf.convert_to_tensor(weights)

    def __call__(self, p):
        p = find(self.weights, p, p.shape)
        return p

    def get_config(self):
        return {'name': self.__class__.__name__}

当我使用上述约束训练网络时,权重仅来自可能的权重值,但是训练时间却大大增加。没有二进制搜索功能,我的GPU被充分利用,但是当我训练二进制搜索功能时,利用率下降到2%。有人可以帮我吗?

1 个答案:

答案 0 :(得分:0)

从您的描述看来,剪辑操作的某些部分似乎在需要RAM-VRAM通信的CPU上执行,这非常慢。

但是,如果您尝试进行传统的NN量化,则实际上为此目的构建了一个完整的TF模块,您可能需要检查一下,也许它涵盖了您的用例。

https://www.tensorflow.org/api_docs/python/tf/quantization/quantize