How to scale tf.nn.embedding_lookup_sparse

Asked: 2016-11-16 09:56:23

Tags: tensorflow sharding partition

I am trying to build a very large sparse model (for example LR, if there is only one embedding layer). The input dimension can be as large as 100,000,000, and the samples are very sparse: the average number of non-zero values per sample is around 100. Since the weights are so big, we have to partition them and distribute the shards across different servers. Here is the code:

import tensorflow as tf

# weights_shape[0] is the input dimension (up to 100,000,000 rows),
# so the variable is split row-wise into num_shards pieces.
weights = tf.get_variable("weights",
                          weights_shape,
                          partitioner=tf.fixed_size_partitioner(num_shards, axis=0),
                          initializer=tf.truncated_normal_initializer(stddev=0.1))
biases = tf.get_variable("biases",
                         biases_shape,
                         initializer=tf.truncated_normal_initializer(stddev=0.1))

# ids and values are SparseTensors holding, per sample, the non-zero
# feature ids and their corresponding weights.
result = tf.nn.embedding_lookup_sparse(weights,
                                       ids,
                                       values,
                                       partition_strategy="div",
                                       combiner="sum") + biases

[Figure: the TensorFlow graph generated for this op]

From the graph, embedding_lookup_sparse simply gathers the sharded weights, which causes a lot of unnecessary network traffic. This looks dumb. The reasonable approach would be to perform the lookup on each shard locally and independently, send back only the looked-up results, and aggregate them; that way the traffic is greatly reduced. I wonder whether TensorFlow supports this mode? Of course, I can achieve it with custom code.
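
For what it's worth, that pattern can be spelled out by hand: route each id to the shard that owns it, gather locally on that shard's device, and stitch the rows back together. Below is a rough sketch under a "div"-style equal split; every helper name here is made up, not a TensorFlow API:

import tensorflow as tf

def local_sharded_lookup(shards, flat_ids, flat_weights, segment_ids,
                         num_shards, shard_size):
  # shards       -- list of variables, each of shape [shard_size, dim]
  # flat_ids     -- 1-D int64 tensor of global feature ids
  # flat_weights -- 1-D float tensor of per-feature weights
  # segment_ids  -- 1-D tensor mapping each id to its sample (row) index
  #
  # "div" strategy: shard p owns the contiguous id range
  # [p * shard_size, (p + 1) * shard_size).
  shard_idx = tf.cast(flat_ids // shard_size, tf.int32)
  local_ids = flat_ids % shard_size

  id_parts = tf.dynamic_partition(local_ids, shard_idx, num_shards)
  w_parts = tf.dynamic_partition(flat_weights, shard_idx, num_shards)
  pos_parts = tf.dynamic_partition(tf.range(tf.size(flat_ids)),
                                   shard_idx, num_shards)

  partial = []
  for p in range(num_shards):
    # Colocate the gather with its shard so that only the selected
    # rows, not the whole shard, cross the network.
    with tf.device(shards[p].device):
      rows = tf.gather(shards[p], id_parts[p])
      partial.append(rows * tf.expand_dims(w_parts[p], 1))

  # Restore the original id order, then sum rows per sample
  # (the combiner="sum" step).
  rows_in_order = tf.dynamic_stitch(pos_parts, partial)
  return tf.segment_sum(rows_in_order, segment_ids)

With SparseTensor inputs like the ones above, flat_ids would come from ids.values, flat_weights from values.values, and segment_ids from ids.indices[:, 0] (which is already sorted in row-major order, as tf.segment_sum requires).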

EDIT

A solution that works as expected:

num_shards = 2
weights = []
assert weights_shape[0] % num_shards == 0
for i in range(num_shards):
  # Each shard holds an equal, contiguous slice of the rows, matching
  # the "div" partition strategy used in the lookup below.
  weights_i = tf.get_variable("weights-%02d" % i,
                              [weights_shape[0] // num_shards] + weights_shape[1:],
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
  weights.append(weights_i)
biases = tf.get_variable("biases",
                         biases_shape,
                         initializer=tf.truncated_normal_initializer(stddev=0.1))

# Passing a Python list of variables (rather than one partitioned
# variable) makes embedding_lookup_sparse treat each entry as a shard.
result = tf.nn.embedding_lookup_sparse(weights,
                                       ids,
                                       values,
                                       partition_strategy="div",
                                       combiner="sum") + biases

0 Answers