How do I save specific variables in TensorFlow?

Asked: 2017-03-28 11:34:27

Tags: python tensorflow

I built a small network to test saving a model. Here is my code:

import tensorflow as tf
import numpy as np
import time

dimensions=100
batch_size=128

def add_layer(inputs, in_size, out_size, activation_function=None):
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

def f(batch_size,val,dims):
  a = np.zeros(batch_size,dtype=np.int32)+val
  b = np.zeros((batch_size, dims))
  b[np.arange(batch_size), a] = 1
  return b

xs = tf.placeholder(tf.float32, [None, dimensions])
ys = tf.placeholder(tf.float32, [None, 43])

l1 = add_layer(xs, dimensions, 64, activation_function=None)
l2 = add_layer(l1, 64, 64, activation_function=tf.nn.sigmoid)
prediction = add_layer(l2, 64, 43, activation_function=None)

loss = tf.reduce_mean(tf.square(ys - prediction))
train_step = tf.train.AdamOptimizer(0.003).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())


for step in range(100):
  start_time = time.time()
  X = f(batch_size=batch_size,val=step,dims=dimensions)
  y = np.random.rand(batch_size,43)
  sess.run(train_step, feed_dict={xs:X, ys:y})
  duration = time.time()-start_time
  if step%10 == 0:
    loss_value = sess.run(loss, feed_dict={xs: X, ys: y})
    format_str = ('step %d,loss=%5.2f (%.1f examples/sec;%.3f sec/batch)')
    print(format_str % (step, loss_value, batch_size / duration, float(duration)))

saver = tf.train.Saver()
save_path = saver.save(sess, "./save_net.ckpt")
sess.close()

This saves all of the variables to "./save_net.ckpt".

But I only want to save the weights and biases of layer l1. How can I do that?

How do I extract just these variables in TensorFlow?

1 answer:

Answer 0: (score: 0)

You should look at the TensorFlow documentation on variables, in particular the section on choosing which variables to save and restore.

In your case, you should pass names to the function that creates the weights and biases, so that they are declared as

Weights = tf.Variable(tf.random_normal([in_size, out_size]), name=weights_name)
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name=biases_name)
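A minimal sketch of how the function from the question could be adapted, assuming the TF 1.x API (imported here through `tensorflow.compat.v1` so that it also runs on a TF 2.x install). The `name` parameter and the extra return values are hypothetical additions for illustration; they are not part of the original `add_layer`.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x API via the compat module

tf.disable_v2_behavior()


def add_layer(inputs, in_size, out_size, name, activation_function=None):
    # `name` is a hypothetical extra parameter used as a prefix for the variable names.
    weights = tf.Variable(tf.random_normal([in_size, out_size]),
                          name=name + "_weights")
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, name=name + "_biases")
    wx_plus_b = tf.matmul(inputs, weights) + biases
    if activation_function is None:
        outputs = wx_plus_b
    else:
        outputs = activation_function(wx_plus_b)
    # Return the variables as well, so the caller can hand them to a Saver.
    return outputs, weights, biases


xs = tf.placeholder(tf.float32, [None, 100])
l1, l1_weights, l1_biases = add_layer(xs, 100, 64, name="l1")
print(l1_weights.name, l1_biases.name)  # graph names of the two l1 variables
```

Returning the variables is only one option; they could also be looked up afterwards by name, for example by filtering `tf.global_variables()`.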

Then create the saver:

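# The dict maps checkpoint names to the tf.Variable objects themselves (not to strings),
# so add_layer must also return the variables it creates.
# To save only layer l1, list just the two l1 entries.
saver = tf.train.Saver({"l1_weights": l1_weights,
                        "l1_biases": l1_biases,
                        "l2_weights": l2_weights,
                        "l2_biases": l2_biases})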
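To check that such a saver really writes only the selected variables, here is a small self-contained sketch, again assuming the TF 1.x API through `tensorflow.compat.v1`. The variables are stand-ins with the same shapes as layers l1 and l2 in the question, and the temporary checkpoint path is hypothetical. It saves only the l1 variables, lists the contents of the checkpoint, and restores them in a fresh session.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.x API via the compat module

tf.disable_v2_behavior()

# Stand-ins for the variables of layers l1 and l2 in the question.
l1_weights = tf.Variable(tf.random_normal([100, 64]), name="l1_weights")
l1_biases = tf.Variable(tf.zeros([1, 64]) + 0.1, name="l1_biases")
l2_weights = tf.Variable(tf.random_normal([64, 64]), name="l2_weights")

# Only the variables listed here are written to (and read from) the checkpoint.
saver = tf.train.Saver({"l1_weights": l1_weights, "l1_biases": l1_biases})

ckpt_path = os.path.join(tempfile.mkdtemp(), "l1_only.ckpt")  # hypothetical path

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_weights = sess.run(l1_weights)
    saver.save(sess, ckpt_path)

# Inspect which variables the checkpoint actually contains.
saved_names = sorted(name for name, _ in tf.train.list_variables(ckpt_path))
print(saved_names)

with tf.Session() as sess:
    # l2 is not in the checkpoint, so it still has to be initialized by hand.
    sess.run(l2_weights.initializer)
    saver.restore(sess, ckpt_path)
    restored_weights = sess.run(l1_weights)

print(np.allclose(saved_weights, restored_weights))
```

Passing a plain list such as `tf.train.Saver([l1_weights, l1_biases])` should work as well; in that case the checkpoint keys are taken from the variables' own graph names, which is where the `name=` arguments matter.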