Tensorflow保存/恢复批量规范

时间:2017-08-21 15:30:06

标签: tensorflow batch-normalization

我在Tensorflow中训练了一个具有批量规范的模型。我想保存模型并将其恢复以供进一步使用。批量规范由

完成
True

其中阶段在培训期间为False,在测试期间为saver = tf.train.Saver() saver.save(sess, savedir + "ckpt")

似乎只是简单地调用

Attempting to use uninitialized value batch_normalization_585/beta

不能很好地工作,因为当我恢复模型时,它首先说成功恢复。如果我只在图中运行一个节点,它也会说CREATE FUNCTION [JACKINABOX](@TextToUpdate varchar(30), @FilterId int) RETURNS varchar(30) AS BEGIN DECLARE @Keyword varchar(30) DECLARE LonelyCursor CURSOR FOR SELECT Keyword FROM ReplacementInformation WHERE Id = @FilterId OPEN LonelyCursor ; FETCH NEXT FROM LonelyCursor INTO @Keyword WHILE @@FETCH_STATUS = 0 -- While there still remains keywords to process. BEGIN WHILE 1 = 1 -- Not sure, but I think this nested loop can be unlooped if [FETCH NEXT] was cut & pasted to replace [BREAK]. BEGIN IF(CHARINDEX(@Keyword, @TextToUpdate) = 0) BREAK -- If cannot find current keyword anymore, move on to next keyword. ELSE -- Otherwise, update text then check again for same keyword. SET @TextToUpdate = REPLACE(@TextToUpdate, @Keyword, CONCAT('Replaced_', @Keyword)) END FETCH NEXT FROM LonelyCursor INTO @Keyword END CLOSE LonelyCursor ; DEALLOCATE LonelyCursor RETURN @TextToUpdate END 。这是否与正确保存模型或我错过的其他内容有关?

2 个答案:

答案 0 :(得分:6)

我还有"尝试使用未初始化的值batch_normalization_585 / beta"错误。这是因为通过像这样用空括号声明保护程序:

         saver = tf.train.Saver() 

保护程序将保存tf.trainable_variables()中包含的变量,这些变量不包含批量标准化的移动平均值。要将这些变量包含在保存的ckpt中,您需要执行以下操作:

         saver = tf.train.Saver(tf.global_variables())

保存所有变量,因此非常耗费内存。或者您必须识别具有移动平均值或方差的变量,并通过声明它们来保存它们:

         saver = tf.train.Saver(tf.trainable_variables() + list_of_extra_variables)

答案 1 :(得分:3)

不确定是否需要解释,但以防万一(以及其他潜在的观众)。

每当您在TensorFlow中创建操作时,都会向图中添加一个新节点。图中没有两个节点可以具有相同的名称。您可以定义您创建的任何节点的名称,但是如果您没有给出名称,TensorFlow将以确定的方式为您选择一个(也就是说,不是随机的,而是始终使用相同的序列)。如果添加两个数字,它可能是Add,但如果再进行一次添加,因为没有两个节点可以具有相同的名称,它可能类似于Add_2。在图表中创建节点后,其名称将无法更改。许多函数依次创建多个子节点;例如,tf.layers.batch_normalization会创建一些内部变量betagamma

保存和恢复按以下方式工作:

  1. 您可以创建一个表示所需模型的图表。此图表包含将由保护程序保存的变量。
  2. 您可以使用该图表初始化,训练或执行任何操作,并为模型中的变量分配一些值。
  3. 您可以在保护程序上调用save,以便将变量的值保存到文件中。
  4. 现在,您在另一个图形中重新创建模型(它可以是完全不同的Python会话,也可以只是与第一个图形共存的另一个图形)。必须以与第一个模型完全相同的方式创建模型。
  5. 您在保护程序上调用restore以检索变量的值。
  6. 为了使其正常工作,第一个和第二个图表中变量的名称必须完全相同

    在您的示例中,TensorFlow抱怨变量batch_normalization_585/beta。您似乎已在同一图表中调用tf.layers.batch_normalization近600次,因此您可以使用多个beta变量。我怀疑你确实需要这么多,所以我猜你只是在试验API并最终获得了那么多副本。

    以下是应该有用的草稿:

    import tensorflow as tf
    
    def make_model():
        input = tf.placeholder(...)
        phase = tf.placeholder(...)
        input_norm = tf.layers.batch_normalization(input, training=phase))
        # Do some operations with input_norm
        output = ...
        saver = tf.train.Saver()
        return input, output, phase, saver
    
    # We work with one graph first
    g1 = tf.Graph()
    with g1.as_default():
        input, output, phase, saver = make_model()
        with tf.Session() as sess:
            # Do your training or whatever...
            saver.save(sess, savedir + "ckpt")
    
    # We work with a second different graph now
    g2 = tf.Graph()
    with g2.as_default():
        input, output, phase, saver = make_model()
        with tf.Session() as sess:
            saver.restore(sess, savedir + "ckpt")
            # Continue using your model...
    

    同样,典型的情况是并没有两个图并排,而是有一个图,然后在另一个Python会话中重新创建它,但最后两个都是相同的。重要的是,在两种情况下,模型都以相同的方式创建(因此具有相同的节点名称)。