How to reuse only some variables in TensorFlow?

Posted: 2016-07-24 04:11:14

Tags: tensorflow

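I have a TensorFlow class that holds weights and document embeddings, and I use it for both training and validation. My question: in a TensorFlow session, can the validation model reuse only the weights from training, but not the embeddings, and learn new document embeddings for the validation set? Here is a code snippet: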

import tensorflow as tf  # TensorFlow 1.x graph-mode API

class NewModel(object):
  def __init__(self, is_training, vocabulary_size, embedding_size):
    self.X = tf.placeholder("float", [None, 300])
    self.doc_int = tf.placeholder(tf.int32, shape=[None])

    self.embeddings = tf.get_variable("embedding", [vocabulary_size, embedding_size],
                                      initializer=tf.random_uniform_initializer(-0.1, 0.1))
    self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)
    # weight_shape and bias_shape are defined elsewhere (not shown in this snippet)
    self.weights = tf.get_variable("weights", weight_shape, initializer=tf.random_normal_initializer())
    biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0))
    # Some neural network with an optimizer and loss that trains the weights and embeddings...

with tf.Graph().as_default(), tf.Session() as sess:

  initializer = tf.random_uniform_initializer()
  with tf.variable_scope("foo", reuse=None, initializer=initializer):
    train = NewModel(is_training=True, vocabulary_size=4000, embedding_size=50)
  with tf.variable_scope("foo", reuse=True, initializer=initializer):
    valid = NewModel(is_training=False, vocabulary_size=1000, embedding_size=50)
  # Here is where I am confused. I want to use the trained weights but not the
  # embeddings, and I want new embeddings to be trained for the validation set.
  sess.run(tf.global_variables_initializer())
  # will call some function to run epochs and stuff

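Using different scope names might help, but I would still like some advice on it. Or maybe there is a way to specify which variables should be reused.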
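A minimal sketch of the second idea, assuming the `tf.compat.v1` API (TensorFlow 2.x running in graph mode): put the weights and each embedding table under separate variable scopes, then select variables by scope prefix with `tf.get_collection`. The scope names `shared`, `train_embed` and `valid_embed` and all shapes are hypothetical.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, sessions and variable scopes

graph = tf.Graph()
with graph.as_default():
    # Hypothetical scope names: one for the shared layer, one per embedding table.
    with tf.variable_scope("shared"):
        tf.get_variable("weights", [50, 10])
        tf.get_variable("biases", [10], initializer=tf.zeros_initializer())
    with tf.variable_scope("train_embed"):
        tf.get_variable("embedding", [4000, 50])
    with tf.variable_scope("valid_embed"):
        tf.get_variable("embedding", [1000, 50])

    # get_collection keeps only variables whose name matches `scope` (re.match).
    shared_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="shared")
    valid_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="valid_embed")

print([v.name for v in shared_vars])  # ['shared/weights:0', 'shared/biases:0']
print([v.name for v in valid_vars])   # ['valid_embed/embedding:0']
```

Either list can then be passed as `var_list` to an optimizer's `minimize` or to `tf.train.Saver`, so that only the selected variables are trained or restored. Because `scope` is matched as a regular expression against the start of each name, `shared` would also match a hypothetical scope named `shared_2`.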

1 answer:

Answer 0 (score: 0)

I would probably restructure the NewModel class:

import tensorflow as tf  # TensorFlow 1.x graph-mode API

class NewModel(object):
    def __init__(self, vocabulary_size, embedding_size, initializer):
        self.X = tf.placeholder("float", [None, 300])
        self.doc_int = tf.placeholder(tf.int32, shape=[None])
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        self.initializer = initializer

    def initialize_embeddings(self):
        # Embeddings get their own variable scope, separate from the weights.
        with tf.variable_scope("embed", initializer=self.initializer) as scope:
            self.embeddings = tf.get_variable("embedding", [self.vocabulary_size, self.embedding_size])
            self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)
            scope.reuse_variables()

    def initialize_weights(self, weight_shape, bias_shape):
        # Weights and biases live in a separate "weight" scope.
        with tf.variable_scope("weight", initializer=self.initializer) as scope:
            self.weights = tf.get_variable("weights", weight_shape)
            self.biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0))
            scope.reuse_variables()

    def train_network(self):
        # Some neural network with an optimizer and loss that trains the weights and embeddings.
        pass

    def validate_network(self):
        # A function for the validation process.
        pass
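Creating a variable and resetting it are different operations in graph-mode TensorFlow. Calling `tf.get_variable` a second time in the same scope does not give a fresh table: without reuse it raises a `ValueError` because the variable already exists, and with reuse it returns the existing (trained) variable. A sketch of one way to reset only the embeddings, assuming the `tf.compat.v1` API: build an initializer op for that single variable with `tf.variables_initializer` and run it before validation. The scope names mirror the class above (`embed`, `weight`); the shapes and the assignments standing in for training are hypothetical.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, sessions and variable scopes

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("embed"):
        embeddings = tf.get_variable(
            "embedding", [4000, 50],
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
    with tf.variable_scope("weight"):
        weights = tf.get_variable("weights", [50, 10])

    # Re-entering "embed" and calling get_variable again does not build a new
    # table: without reuse it raises ValueError because the variable exists.
    try:
        with tf.variable_scope("embed"):
            tf.get_variable("embedding", [4000, 50])
        duplicate_error = False
    except ValueError:
        duplicate_error = True

    # Hypothetical stand-in for training: overwrite both variables with known values.
    fake_train = tf.group(embeddings.assign(tf.ones([4000, 50])),
                          weights.assign(tf.fill([50, 10], 2.0)))
    # An initializer op that touches only the embedding table.
    reset_embeddings = tf.variables_initializer([embeddings])
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(fake_train)
    sess.run(reset_embeddings)  # fresh random embeddings before validation
    emb_val, w_val = sess.run([embeddings, weights])

print(duplicate_error)                 # True
print(np.all(w_val == 2.0))            # True: weights kept their "trained" values
print(np.all(np.abs(emb_val) <= 0.1))  # True: embeddings re-drawn from U(-0.1, 0.1)
```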

This way you can separate the embedding initialization from the weight and bias initialization. The new class can be used like this:

with tf.Graph().as_default(), tf.Session() as sess:

    initializer = tf.random_uniform_initializer()
    model = NewModel(vocabulary_size=4000, embedding_size=50, initializer=initializer)  # construct a model instance
    model.initialize_weights(weight_shape, bias_shape)  # create the weights and biases (shapes defined elsewhere)
    model.initialize_embeddings()  # create the embeddings
    sess.run(tf.global_variables_initializer())  # give every variable its initial value
    model.train_network()  # train the network
    # Before starting validation, re-initialize only the embeddings;
    # the trained weights and biases keep their values.
    sess.run(model.embeddings.initializer)
    model.validate_network()
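The answer keeps a single vocabulary size, while the question uses 4000 words for training and 1000 for validation. A minimal sketch that handles both, assuming the `tf.compat.v1` API: each model gets its own embedding scope, the dense layer sits in a shared scope that is entered with `reuse=True` the second time, and the validation optimizer receives `var_list=[valid_emb]` so that only the new embeddings are trained. `build_model`, the scope names, `NUM_CLASSES` and the single dense layer are hypothetical stand-ins for the real network.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs, sessions and variable scopes

EMBED_DIM = 50
NUM_CLASSES = 3  # hypothetical number of output classes


def build_model(doc_ids, labels, vocab_size, embed_scope, reuse):
    """Embedding lookup followed by one dense layer (stand-in for the real network)."""
    # Each model gets its own embedding table, selected by embed_scope.
    with tf.variable_scope(embed_scope):
        embeddings = tf.get_variable(
            "embedding", [vocab_size, EMBED_DIM],
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
    embedded = tf.nn.embedding_lookup(embeddings, doc_ids)
    # The dense layer lives in one shared scope; reuse=True makes get_variable
    # return the existing variables instead of creating new ones.
    with tf.variable_scope("shared", reuse=reuse):
        weights = tf.get_variable("weights", [EMBED_DIM, NUM_CLASSES],
                                  initializer=tf.random_normal_initializer())
        biases = tf.get_variable("biases", [NUM_CLASSES],
                                 initializer=tf.constant_initializer(0.0))
    logits = tf.matmul(embedded, weights) + biases
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))
    return embeddings, weights, biases, loss


graph = tf.Graph()
with graph.as_default():
    doc_ids = tf.placeholder(tf.int32, [None])
    labels = tf.placeholder(tf.int32, [None])
    train_emb, train_w, train_b, train_loss = build_model(
        doc_ids, labels, 4000, "train_embed", reuse=None)
    valid_emb, valid_w, valid_b, valid_loss = build_model(
        doc_ids, labels, 1000, "valid_embed", reuse=True)

    opt = tf.train.GradientDescentOptimizer(0.1)
    # Training updates its own embeddings and the shared layer...
    train_op = opt.minimize(train_loss, var_list=[train_emb, train_w, train_b])
    # ...while validation updates only its new embedding table.
    valid_op = opt.minimize(valid_loss, var_list=[valid_emb])
    init = tf.global_variables_initializer()

rng = np.random.RandomState(0)
feed = {doc_ids: rng.randint(0, 1000, size=8),
        labels: rng.randint(0, NUM_CLASSES, size=8)}

with tf.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(train_op, feed)  # one training step (a real loop would run many)
    w_before, emb_before = sess.run([train_w, valid_emb])
    sess.run(valid_op, feed)  # one step on the validation embeddings only
    w_after, emb_after = sess.run([train_w, valid_emb])

print(valid_w.name)                        # shared/weights:0
print(np.allclose(w_before, w_after))      # True: shared weights untouched
print(np.allclose(emb_before, emb_after))  # False: validation embeddings updated
```

Because `reuse=True` makes `tf.get_variable` return the existing `shared/weights` variable, the validation model uses exactly the weights that training produced, while the different embedding scopes let the two tables have different vocabulary sizes.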