I want to save the weight and bias tensors as checkpoints during training steps. I used fully_connected() from tf.contrib.layers to build several fully connected layers. To do that, I need to get hold of the weight and bias tensors of those fully connected layers. How can I do this?
Answer 0 (score: 4)
Just to mention a few things:

Since fully_connected() creates its weights and biases with trainable=True (the default), they are added to GraphKeys.TRAINABLE_VARIABLES, which is a subset of GraphKeys.GLOBAL_VARIABLES. Thus, if you use saver = tf.train.Saver(var_list=tf.global_variables()) and call saver.save(sess, save_path, global_step) at some point, the weights and biases will be saved.
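For example, a minimal training-loop sketch; the input shape, layer sizes, and the './model.ckpt' path are illustrative, not from the question:

import tensorflow as tf

# Illustrative input placeholder and two fully connected layers.
x = tf.placeholder(tf.float32, shape=(None, 10), name='x')
h = tf.contrib.layers.fully_connected(x, 32)
logits = tf.contrib.layers.fully_connected(h, 2, activation_fn=None)

# fully_connected's weights and biases live in GLOBAL_VARIABLES, so this
# Saver checkpoints them along with everything else in the graph.
saver = tf.train.Saver(var_list=tf.global_variables())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # ... sess.run(train_op, feed_dict=...) would go here ...
        if step % 500 == 0:
            saver.save(sess, './model.ckpt', global_step=step)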
You can then retrieve the individual tensors with tf.get_variable or tf.get_default_graph().get_tensor_by_name, given the correct variable name, as mentioned by the other answer.

Alternatively, consider the object-oriented layer classes tf.layers.Dense and tf.layers.Conv2D. Once built, they have weights / variables properties that return the weight and bias tensors.
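A short sketch of both approaches, assuming the default variable names that tf.contrib.layers.fully_connected produces (visible in the session log in the other answer):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 10))
fc = tf.contrib.layers.fully_connected(x, 4)

# Default names created by fully_connected; ':0' selects the output tensor
# of the variable op.
graph = tf.get_default_graph()
weights = graph.get_tensor_by_name('fully_connected/weights:0')
biases = graph.get_tensor_by_name('fully_connected/biases:0')

# The object-oriented API keeps the references for you.
dense = tf.layers.Dense(4)
y = dense(x)             # building the layer creates the kernel and bias
print(dense.weights)     # [<kernel Variable>, <bias Variable>]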
Answer 1 (score: 1)
tf.trainable_variables() gives you a list of all the trainable variables in the network. You can do this more cleanly with variable_scope and name_scope, as described here: How to get weights from tensorflow fully_connected (see the scoped sketch after the session log below).
In [1]: import tensorflow as tf
In [2]: a1 = tf.get_variable(name='a1', shape=(1,2), dtype=tf.float32)
In [3]: fc = tf.contrib.layers.fully_connected(a1, 4)
In [4]: sess = tf.Session()
In [5]: tf.trainable_variables()
Out[5]:
[<tf.Variable 'a1:0' shape=(1, 2) dtype=float32_ref>,
<tf.Variable 'fully_connected/weights:0' shape=(2, 4) dtype=float32_ref>,
<tf.Variable 'fully_connected/biases:0' shape=(4,) dtype=float32_ref>]
In [6]: for var in tf.trainable_variables():
...: if 'weights' in var.name or 'biases' in var.name:
...: print(var)
...:
<tf.Variable 'fully_connected/weights:0' shape=(2, 4) dtype=float32_ref>
<tf.Variable 'fully_connected/biases:0' shape=(4,) dtype=float32_ref>
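Following the variable_scope suggestion above, a small sketch in which the scope names 'fc1' and 'fc2' are made up for illustration:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 10))

# Illustrative scope names; pick whatever naming scheme you like.
with tf.variable_scope('fc1'):
    h = tf.contrib.layers.fully_connected(x, 8)
with tf.variable_scope('fc2'):
    out = tf.contrib.layers.fully_connected(h, 2, activation_fn=None)

# Fetch only the variables created under a given scope.
fc1_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc1')
for v in fc1_vars:
    print(v.name)   # fc1/fully_connected/weights:0, fc1/fully_connected/biases:0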