在TF v1.3
中,我可以使用tf.train.import_meta_graph
和restore(sess,tf.train.latest_checkpoint
恢复模型的元数据和权重。运行sess.run()
会给出张量值。
我的问题是如何将新值分配(重写)到张量值并将其保存回模型中以便进一步处理。比如说,我对给定的图层有这些值:
>>>(sess.run('MobilenetV1/Conv2d_0/weights:0'))
...
1.55560642e-01, 1.29789323e-01, 2.59163193e-02,
8.00046027e-02, 4.73752208e-02, -5.41094005e-01,
-8.93476382e-02, -9.48717445e-02]]]], dtype=float32)
如何使用tf.assign()为最后8个打印值指定不同的值并将其保存回检查点。
答案 0 :(得分:1)
也许......这样可以为你效劳:
<section>section</section>
作为示例,我使用import tensorflow as tf
import numpy as np
tf.reset_default_graph()
with tf.variable_scope('scope1') as scope:
w = tf.get_variable('w', shape=[4,4])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
np_w = sess.run(w)
print(np_w)
np_w[2:,2:] = np.ones((2,2))
print(" . '"*10)
print(np_w)
print(" . '"*10)
with tf.variable_scope(scope, reuse=True):
v=tf.get_variable('w')
sess.run(tf.assign(w, np_w))
print(sess.run(w))
创建了一个随机 4x4 矩阵,并将子矩阵重新分配给 1 。希望这会有所帮助。
修改强>
从您保存的模型中访问变量 w ,然后为其分配一个新值:
get_variable
现在你可以按照我上面的建议继续攻击 w 。 w 应与恢复模型中的变量具有相同的名称,例如: w = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == 'w:0'][0]
。