Reusing TensorFlow variables across different name scopes

Asked: 2017-11-18 02:39:27

Tags: python tensorflow recurrent-neural-network

I ran into a problem reusing variables across different name scopes. The code below separates the source embedding and the target embedding into two different spaces; what I want to do is put the source and the target in the same space, reusing the variables of a single lookup table.

def decode_first_word(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    '''Applying bidirectional encoding for source-side inputs and first-word decoding.'''
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Source_Side'):
            source_embedding_tensor = self._src_lookup_table(source_vocab_id_tensor)
    with tf.name_scope('Encoding_Layer'):
        source_concated_hidden_tensor = self._encoder.get_biencoded_tensor(\
            source_embedding_tensor, source_mask_tensor)
    with tf.name_scope('Decoding_Layer_First'):
        rvals = self.decode_next_word(source_concated_hidden_tensor, source_mask_tensor, \
            None, None, None, scope, reuse)
    return rvals + [source_concated_hidden_tensor]


def decode_next_word(self, enc_concat_hidden, src_mask, cur_dec_hidden, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    '''Applying one-step decoding.'''
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Target_Side'):
            cur_trg_wemb = None
            if cur_trg_wid is not None:
                cur_trg_wemb = self._trg_lookup_table(cur_trg_wid)

I want to change them to the following, so that there will be only one embedding node in the whole graph:

def decode_first_word_shared_embedding(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Bi_Side'):
            source_embedding_tensor = self._bi_lookup_table(source_vocab_id_tensor)
    with tf.name_scope('Encoding_Layer'):
        source_concated_hidden_tensor = self._encoder.get_biencoded_tensor(\
            source_embedding_tensor, source_mask_tensor)
    with tf.name_scope('Decoding_Layer_First'):
        rvals = self.decode_next_word_shared_embedding(source_concated_hidden_tensor, source_mask_tensor, \
            None, None, None, scope, reuse)
    return rvals + [source_concated_hidden_tensor]

def decode_next_word_shared_embedding(self, enc_concat_hidden, src_mask, cur_dec_hidden, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    with tf.name_scope('Word_Embedding_Layer'):
        cur_trg_wemb = None
        if cur_trg_wid is not None:
            with tf.variable_scope('Bi_Side'):
                cur_trg_wemb = self._bi_lookup_table(cur_trg_wid)

How can I achieve this?

2 Answers:

Answer 0 (score: 1)

I solved it by using a dictionary to hold the embedding weight matrix, following a hint from https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/
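A minimal sketch of that dictionary approach might look as follows; the Model class, the _bi_lookup_table helper, and the vocab_size/emb_dim parameters are hypothetical names for illustration, not from the original post:

import tensorflow as tf

class Model(object):
    def __init__(self, vocab_size, emb_dim):
        # Create the embedding weight matrix once and cache it in a dict;
        # every later lookup fetches the same Variable instead of creating
        # a new one under a different scope.
        self._weights = {}
        with tf.variable_scope('Bi_Side'):
            self._weights['bi_var'] = tf.get_variable(
                'bi_var', [vocab_size, emb_dim], dtype=tf.float32)

    def _bi_lookup_table(self, word_ids):
        # Called from both the source and the target side, so the whole
        # graph contains a single shared embedding node.
        return tf.nn.embedding_lookup(self._weights['bi_var'], word_ids)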

Answer 1 (score: 1)

One solution is to save the variable_scope instance and reuse it.


def decode_first_word_shared_embedding(self, source_vocab_id_tensor, source_mask_tensor, scope, reuse):
    with tf.name_scope('Word_Embedding_Layer'):
        with tf.variable_scope('Bi_Side'):
            source_embedding_tensor = self._bi_lookup_table(source_vocab_id_tensor)
            shared_variable_scope = tf.get_variable_scope()

    with tf.name_scope('Encoding_Layer'):
        source_concated_hidden_tensor = self._encoder.get_biencoded_tensor(\
            source_embedding_tensor, source_mask_tensor)
    with tf.name_scope('Decoding_Layer_First'):
        rvals = self.decode_next_word_shared_embedding(source_concated_hidden_tensor, source_mask_tensor, \
            None, shared_variable_scope, None, scope=scope, reuse=reuse)
    return rvals + [source_concated_hidden_tensor]

def decode_next_word_shared_embedding(self, enc_concat_hidden, src_mask, cur_dec_hidden, shared_variable_scope, \
                            cur_trg_wid, trg_mask=None, scope=None, reuse=False, \
                            src_side_pre_act=None):
    with tf.variable_scope('Target_Side'):
        cur_trg_wemb = None
        if cur_trg_wid is not None:
            with tf.variable_scope(shared_variable_scope, reuse=True):
                cur_trg_wemb = self._bi_lookup_table(cur_trg_wid)

Here is my demo code:

with tf.variable_scope('Word_Embedding_Layer'):
    with tf.variable_scope('Bi_Side'):
        v = tf.get_variable('bi_var', [1], dtype=tf.float32)
        reuse_scope = tf.get_variable_scope()
with tf.variable_scope('Target_side'):
    # ... some other code ...
    with tf.variable_scope(reuse_scope, reuse=True):
        w = tf.get_variable('bi_var', [1], dtype=tf.float32)
print(v.name)
print(w.name)
assert v == w

Output:
Word_Embedding_Layer/Bi_Side/bi_var:0
Word_Embedding_Layer/Bi_Side/bi_var:0
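
Both printed names are identical and the assert passes: when tf.variable_scope is given a VariableScope object instead of a string, it re-enters that scope under its original name prefix, so the enclosing Target_side scope does not appear in w's name.

As an aside, and assuming TensorFlow 1.4 or later is available (an assumption on my part, not something stated in the answers above), reuse=tf.AUTO_REUSE achieves the same sharing without threading a scope object through the call chain. A minimal sketch with hypothetical parameter names:

import tensorflow as tf

def bi_side_lookup(word_ids, vocab_size, emb_dim):
    # AUTO_REUSE creates 'Bi_Side/bi_var' on the first call and silently
    # reuses the same Variable on every later call, from any name scope.
    with tf.variable_scope('Bi_Side', reuse=tf.AUTO_REUSE):
        table = tf.get_variable('bi_var', [vocab_size, emb_dim], dtype=tf.float32)
    return tf.nn.embedding_lookup(table, word_ids)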