I'm trying to extract the weights from a model after training it. Here's a toy example:
import tensorflow as tf
import numpy as np

X_ = tf.placeholder(tf.float64, [None, 5], name="Input")
Y_ = tf.placeholder(tf.float64, [None, 1], name="Output")
X = ...
Y = ...

with tf.name_scope("LogReg"):
    pred = fully_connected(X_, 1, activation_fn=tf.nn.sigmoid)
    loss = tf.losses.mean_squared_error(labels=Y_, predictions=pred)
    training_ops = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        sess.run(training_ops, feed_dict={
            X_: X,
            Y_: Y
        })
        if (i + 1) % 100 == 0:
            print("Accuracy: ", sess.run(accuracy, feed_dict={
                X_: X,
                Y_: Y
            }))

# Get weights of *pred* here
I've looked at Get weights from tensorflow model and the docs, but couldn't find a way to retrieve the values of the weights.
So in the toy example, suppose X_ has shape (1000, 5): how can I get the 5 values in the one layer's weights?
Answer 0 (score: 11)
There are a few issues in your code that need fixing:
1 - You need to use variable_scope instead of name_scope at the following line (see the TensorFlow documentation for the difference between them):

with tf.name_scope("LogReg"):
2 - To be able to retrieve a variable later in the code, you need to know its name. So you need to assign a name to the variable of interest (if you don't, a default one will be assigned, but then you'd have to figure out what it is!):

pred = tf.contrib.layers.fully_connected(X_, 1, activation_fn=tf.nn.sigmoid, scope='fc1')
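As an aside (not part of the original answer): TF1 composes a variable's full name from its nested scopes plus an output index, which is why the lookup key used below is 'LogReg/fc1/weights'. A minimal sketch of that naming convention, using a hypothetical helper:

```python
# Hypothetical helper illustrating how TF1 composes full variable names:
# "<outer scope>/<layer scope>/<variable>:<output index>".
def scoped_name(*scopes, var="weights", index=0):
    """Join nested scope names, the variable name, and the output index."""
    return "/".join(scopes) + "/" + var + ":" + str(index)

print(scoped_name("LogReg", "fc1"))                 # the weight tensor's name
print(scoped_name("LogReg", "fc1", var="biases"))   # the bias tensor's name
```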
Now let's see how the fixes above help us get a variable's value. Each layer has two types of variables: weights and biases. In the following code snippet (a modified version of yours), I'll only show how to retrieve the weights of the fully connected layer:
X_ = tf.placeholder(tf.float64, [None, 5], name="Input")
Y_ = tf.placeholder(tf.float64, [None, 1], name="Output")
X = np.random.randint(1, 10, [10, 5])
Y = np.random.randint(0, 2, [10, 1])

with tf.variable_scope("LogReg"):
    pred = tf.contrib.layers.fully_connected(X_, 1, activation_fn=tf.nn.sigmoid, scope='fc1')
    loss = tf.losses.mean_squared_error(labels=Y_, predictions=pred)
    training_ops = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    all_vars = tf.global_variables()

    def get_var(name):
        # Return the first global variable whose name starts with the given prefix
        for var in all_vars:
            if var.name.startswith(name):
                return var
        return None

    fc1_var = get_var('LogReg/fc1/weights')
    sess.run(tf.global_variables_initializer())
    for i in range(200):
        _, fc1_var_np = sess.run([training_ops, fc1_var], feed_dict={
            X_: X,
            Y_: Y
        })
        print(fc1_var_np)
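The get_var helper above is plain prefix matching over a list; its logic can be checked without TensorFlow by standing in simple named objects (the class and names below are made up for illustration):

```python
# Pure-Python stand-in for tf.Variable: only the .name attribute matters
# to the prefix lookup used in the answer above.
class NamedVar:
    def __init__(self, name):
        self.name = name

all_vars = [NamedVar("LogReg/fc1/weights:0"),
            NamedVar("LogReg/fc1/biases:0")]

def get_var(name):
    # Return the first variable whose name starts with the given prefix
    for v in all_vars:
        if v.name.startswith(name):
            return v
    return None

print(get_var("LogReg/fc1/weights").name)  # -> LogReg/fc1/weights:0
print(get_var("LogReg/fc2"))               # -> None
```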
Answer 1 (score: 0)
Try this:
with tf.Session() as sess:
    last_check = tf.train.latest_checkpoint(tf_data)
    saver = tf.train.import_meta_graph(last_check + '.meta')
    saver.restore(sess, last_check)
    ######
    Model_variables = tf.GraphKeys.MODEL_VARIABLES
    Global_Variables = tf.GraphKeys.GLOBAL_VARIABLES
    ######
    all_vars = tf.get_collection(Model_variables)
    # print(all_vars)
    for i in all_vars:
        print(str(i) + ' --> ' + str(i.eval()))
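The loop above just pairs each variable's repr with its evaluated value. The printing logic can be sketched without a checkpoint by substituting plain objects for tf.Variable (names and values below are invented for illustration):

```python
# Stand-in mimicking the two tf.Variable members the loop relies on:
# .name and .eval().
class FakeVar:
    def __init__(self, name, value):
        self.name = name
        self._value = value

    def eval(self):
        # A real tf.Variable would evaluate itself in the active session
        return self._value

all_vars = [FakeVar("linear/linear_model/trip_distance/weights/part_0:0", [[2.5997]]),
            FakeVar("linear/linear_model/bias_weights/part_0:0", [4.2798])]

for v in all_vars:
    print(v.name + " --> " + str(v.eval()))
```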
I get:
<tf.Variable 'linear/linear_model/DOLocationID/weights/part_0:0' shape=(1, 1) dtype=float32_ref> --> [[-0.00912262]]
<tf.Variable 'linear/linear_model/PULocationID/weights/part_0:0' shape=(1, 1) dtype=float32_ref> --> [[ 0.00573495]]
<tf.Variable 'linear/linear_model/passenger_count/weights/part_0:0' shape=(1, 1) dtype=float32_ref> --> [[-0.07072949]]
<tf.Variable 'linear/linear_model/trip_distance/weights/part_0:0' shape=(1, 1) dtype=float32_ref> --> [[ 2.59973669]]
<tf.Variable 'linear/linear_model/bias_weights/part_0:0' shape=(1,) dtype=float32_ref> --> [ 4.27982235]