我正在尝试训练仅具有某些值的权重的网络。但是,我这样做的方式需要很长时间,例如对于MNIST上的3层完全连接网络,每个纪元5h。有更快的方法吗?
我正在使用tf.keras构建我的网络。我添加了一个自定义tf.constraint,它在更新权重时对可能的权重值列表进行二进制搜索。我从here找到了适合我的应用程序的二进制搜索代码。为了将二进制搜索功能应用于所有参数,我使用“ tf.map_fn”。
这是约束类:
from tensorflow.python.keras.constraints import Constraint
import tensorflow as tf
# binary search function
def find(weights, query, shape):
vals = tf.map_fn(lambda x: weights[tf.argmin(tf.cast(x >= weights, dtype=tf.int32)[1:] - tf.cast(x >= weights, dtype=tf.int32)[:-1])], tf.reshape(query,[-1]))
return tf.reshape(vals, shape)
class WeightQuantizeClip(Constraint):
# weights parameter holds the possible weight values
def __init__(self, weights = []):
self.weights = tf.convert_to_tensor(weights)
def __call__(self, p):
p = find(self.weights, p, p.shape)
return p
def get_config(self):
return {'name': self.__class__.__name__}
当我使用上述约束训练网络时,权重仅来自可能的权重值,但是训练时间却大大增加。没有二进制搜索功能,我的GPU被充分利用,但是当我训练二进制搜索功能时,利用率下降到2%。有人可以帮我吗?
答案 0 :(得分:0)
从您的描述看来,剪辑操作的某些部分似乎在需要RAM-VRAM通信的CPU上执行,这非常慢。
但是,如果您尝试进行传统的NN量化,则实际上为此目的构建了一个完整的TF模块,您可能需要检查一下,也许它涵盖了您的用例。
https://www.tensorflow.org/api_docs/python/tf/quantization/quantize