我有以下情况:
我已经建立,训练并保存了我的网。现在,我正在尝试恢复网络并可视化权重矩阵。
我知道变量的所有名称,但是我没有为变量分配python标记以传递给会话进行评估。如何检索变量中的数据?
这是我的代码情况:
dataset_params = nn_params.mnist_dataset_params
design = nn_designs.mnist_net_A_design
## Build Housing Object
mnist_nn = nn_class.CNN(**dataset_params)
mnist_nn.build_net(design['design'])
mnist_nn.__setattr__('saved_path',saved_model)
mnist_nn_epoch_file = saved_model+'_epochs_completed.txt'
mnist_nn.__setattr__('epoch_file',mnist_nn_epoch_file)
# evaluate weight variables
session = tf.Session()
saver = tf.train.Saver()
session.run(tf.initialize_all_variables())
saver.restore(session,saved_model)
session.close()
为了拉出权重,我应该传递给会话? (示例权重名称为:' conv_w_1')?
答案 0 :(得分:6)
您可以使用tf.get_collection()
查找方法获取所需的变量:
weight_var = tf.get_collection(tf.GraphKeys.VARIABLES, "conv_w_1")[0]
weight_var_value = session.run(weight_var)