如何从tf.contrib.layers.batch_norm中提取mov_mean,mov_var,scale和shift,然后手动插入test_feed?

时间:2017-06-13 19:50:30

标签: tensorflow

使用Batch-norm,我想提取moving_meanmoving_variance,从tf.contrib.layers.batch_norm缩放和移位数组(通过培训收集)并将其存储在单独的名单。

之后我喜欢在我的预训练测试模型中使用所有4个参数,我通过自己的方法恢复(通过pickle,所以没有.ckpt文件)。

tf.contrib.layers.batch_norm中提取所有这些值并将其手动反馈到测试模型中是否可行?

到目前为止,我研究过:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss)

但我不确定这些值是mov_mean和mov_var还是scale和shift?无论如何,仍然缺少两个参数。

有什么建议吗? (如果没有,我去手动实施)

1 个答案:

答案 0 :(得分:0)

我找到了两种方法:

(1)具体解决方案:从variables_collections中获取所有4个变量(你需要传递一个字符串列表!):

bnorm = tf.contrib.layers.batch_norm(predicted_outputs,center=True,scale=True,
        is_training=True, updates_collections=tf.GraphKeys.UPDATE_OPS,variables_collections=['vars2'], outputs_collections='vars', decay=0.999, zero_debias_moving_mean=False)

然后初始化会话并打印出来:

session = tf.Session()
session.run(tf.global_variables_initializer())
print("variables_collections")
print("mov_mean, mov_var,beta, gamma")
all_vars = tf.get_collection('vars2')
all_vars4 = session.run(all_vars,feed_dict=feedvalid)
print(len(all_vars4))
print(all_vars4[0])
print(all_vars4[1])
print(all_vars4[2])
print(all_vars4[3])

(2)一般解决方案:使用variable_scope(这总是适用于所有类型的函数):

with tf.variable_scope("Bnorm"):
        bnorm = tf.contrib.layers.batch_norm(predicted_outputs,center=True,scale=True,
        is_training=True, decay=0.999, zero_debias_moving_mean=False)
首先是

和init会话,然后像上面那样打印出来:

print("Global: tf.GraphKeys")
print("mov_mean, mov_var,beta, gamma")
all_vars10 = tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES, scope="Bnorm")
all_vars11 = session.run(all_vars10,feed_dict=feedvalid)
print(len(all_vars11))
print(all_vars11)

更新:为测试(推理)指定值(来自培训)

beta_ass = all_vars10[0].assign(np.ones(all_vars10[0].shape) * 7) #get the tf.variable beta and assign a 7 to it
print("assigned: beta with 7")
print (session.run(beta_ass)) #sess.run is needed to really assign values to matrices!!

现在执行上述解决方案之一并检查该值是否已真正分配。另外设置is_training=False以保持测试的所有值都是正确的。