Accessing the weight and bias tensors of fully_connected

Date: 2017-12-17 14:25:06

Tags: machine-learning tensorflow neural-network artificial-intelligence

I want to save the weight and bias tensors to a checkpoint during the training step. I used fully_connected() from tf.contrib.layers to build several fully connected layers. To do this, I need to extract the weight and bias tensors of those fully connected layers. How can I do that?

2 Answers:

Answer 0 (score: 4)

A few points worth mentioning:

  • There is no need to extract the weights and biases just to save them. For tf.layers or tf.contrib.layers, if trainable is set to True, the weights and biases are added to GraphKeys.TRAINABLE_VARIABLES, which is a subset of GraphKeys.GLOBAL_VARIABLES. Thus, if you use saver = tf.train.Saver(var_list=tf.global_variables()) and call saver.save(sess, save_path, global_step) at some point, the weights and biases will be saved.
  • In cases where you really do have to extract the variables, one way is to use tf.get_variable or tf.get_default_graph().get_tensor_by_name with the correct variable name, as mentioned in the other answer.
  • You may have noticed TensorFlow classes such as tf.layers.Dense and tf.layers.Conv2D. Once built, they have weights / variables properties that return the weight and bias tensors.

Answer 1 (score: 1)

tf.trainable_variables() gives you a list of all trainable variables in the network. You can do better with variable_scope and name_scope, as described here: How to get weights from tensorflow fully_connected

In [1]: import tensorflow as tf

In [2]: a1 = tf.get_variable(name='a1', shape=(1,2), dtype=tf.float32)

In [3]: fc = tf.contrib.layers.fully_connected(a1, 4)

In [4]: sess = tf.Session()
2017-12-17 21:09:18.127498: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-12-17 21:09:18.127554: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-12-17 21:09:18.127578: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-12-17 21:09:18.127598: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-12-17 21:09:18.127618: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.

In [5]: tf.trainable_variables()
Out[5]: 
[<tf.Variable 'a1:0' shape=(1, 2) dtype=float32_ref>,
 <tf.Variable 'fully_connected/weights:0' shape=(2, 4) dtype=float32_ref>,
 <tf.Variable 'fully_connected/biases:0' shape=(4,) dtype=float32_ref>]

In [6]: for var in tf.trainable_variables():
   ...:     if 'weights' in var.name or 'biases' in var.name:
   ...:         print(var)
   ...:         
<tf.Variable 'fully_connected/weights:0' shape=(2, 4) dtype=float32_ref>
<tf.Variable 'fully_connected/biases:0' shape=(4,) dtype=float32_ref>

In [7]: