将L2正则化添加到Tensorflow中的特定嵌入

时间:2018-02-12 12:58:15

标签: tensorflow

我正在构建一个像广泛和广泛的模型深度使用Tensorflow。对于离散特征,我首先将它们嵌入到向量空间中,我想知道如何在嵌入上添加L2规范化。

L2正则化运算符tf.nn.l2_loss接受嵌入张量作为输入,但我只想规范specific embeddings whose id appear in current batch of data,而不是整个矩阵。

1 个答案:

答案 0 :(得分:0)

只需使用specific embeddings whose id appear in current batch of data来计算正则化损失。

import tensorflow as tf

ids = sparse_tensor.values
uniq_ids, _ = tf.python.ops.array_ops.unique(ids)    
embedding_index_slices = tf.gather(large_embedding_variable, uniq_ids) 
regularization_loss = tf.nn.l2_loss(embedding_index_slices.values)

...

loss = train_loss + FLAGS.l2 * regularization_loss