我认为Tensorflow saver将按此处所述保存所有变量
如果未将任何参数传递给tf.train.Saver(),则保护程序 处理图中的所有变量。每个变量都保存在 创建变量时传递的名称。
https://www.tensorflow.org/programmers_guide/saved_model
但是,下面的代码中的变量epochCount似乎没有保存。此变量用于跟踪模型针对数据训练的总时间。
当我恢复图形时,它会重置为其初始值,而不是上次保存检查点时的初始值。
在我看来,它只是保存用于计算损失的变量。
这是我的代码。
这是我声明图形的地方:
graph = tf.Graph()
with graph.as_default():
valid_examples = np.array(random.sample(range(1, valid_window), valid_size)) #put inside graph to get new words each time
train_dataset = tf.placeholder(tf.int32, shape=[batch_size, cbow_window*2 ])
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
valid_datasetSM = tf.constant(valid_examples, dtype=tf.int32)
epochCount = tf.get_variable( 'epochCount', initializer= 0) #to store epoch count to total # of epochs are known
embeddings = tf.get_variable( 'embeddings',
initializer= tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
softmax_weights = tf.get_variable( 'softmax_weights',
initializer= tf.truncated_normal([vocabulary_size, embedding_size],
stddev=1.0 / math.sqrt(embedding_size)))
softmax_biases = tf.get_variable('softmax_biases',
initializer= tf.zeros([vocabulary_size]), trainable=False )
embed = tf.nn.embedding_lookup(embeddings, train_dataset) #train data set is
embed_reshaped = tf.reshape( embed, [batch_size*cbow_window*2, embedding_size] )
segments= np.arange(batch_size).repeat(cbow_window*2)
averaged_embeds = tf.segment_mean(embed_reshaped, segments, name=None)
loss = tf.reduce_mean(
tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=averaged_embeds,
labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size))
optimizer = tf.train.AdagradOptimizer(1.0).minimize(loss) #Original learning rate was 1.0
norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))
normalized_embeddings = embeddings / norm
valid_embeddings = tf.nn.embedding_lookup(
normalized_embeddings, valid_dataset)
similarity = tf.matmul(valid_embeddings, tf.transpose(normalized_embeddings))
saver = tf.train.Saver()
如果我从检查点还原图形,则会还原嵌入和softmax_biases,但epochCount会重置为其初始值。 (请注意,我没有调用tf.global_variables_initializer()。run()行,这是导致在恢复检查点后错误地重置变量的常见原因)
这是运行图形的代码
num_steps = 1000001
with tf.Session(graph=graph) as session:
saver.restore(session, './checkpointsBook2VecCbowWindow2Downloaded/bookVec.ckpt' )
average_loss = 0
saveIteration = 1
for step in range(1, num_steps):
batch_data, batch_labels = generate_batch(
batch_size, cbow_window)
feed_dict = {train_dataset : batch_data, train_labels : batch_labels}
_, l = session.run([optimizer, loss], feed_dict=feed_dict)
if step % 20000 == 0:
recEpoch_indexA = epoch_index - recEpoch_indexA
epochCount = tf.add( epochCount, recEpoch_indexA, name=None )
recEpoch_indexA = epoch_index
save_path = saver.save(session, "checkpointsBook2VecCbowWindow2/bookVec.ckpt")
chptName = 'B2VCbowW2Embed256ckpt'+str(saveIteration)
zipfolder(chptName, 'checkpointsBook2VecCbowWindow2')
uploadModel.SetContentFile(chptName+".zip")
uploadModel.Upload()
print("Checkpoint uploaded to Google Drive")
saveIteration += 1
这是我用来打印出训练后保存在检查点中的所有变量的代码。我恢复了图形并打印出所有保存的变量。
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./MODEL/bookVec.ckpt.meta')
saver.restore(sess, './MODEL/bookVec.ckpt' )
for v in tf.get_default_graph().get_collection("variables"):
print('From variables collection ', v)
这是上面代码的输出
From variables collection <tf.Variable 'embeddings:0' shape=(10001, 256) dtype=float32_ref>
From variables collection <tf.Variable 'softmax_weights:0' shape=(10001, 256) dtype=float32_ref>
From variables collection <tf.Variable 'softmax_biases:0' shape=(10001,) dtype=float32_ref>
如图所示,epochCount尚未保存。
答案 0 :(得分:1)
将变量恢复为0的原因是因为它实际上从未更新过(即已正确恢复了)!您正在会话期间通过epochCount
调用覆盖tf.add
,该调用仅返回操作,而没有实际值。也就是说,变量(在Tensorflow上)是“孤立的”,并将永远保持为0。
您可以使用tf.assign
来更新变量。看起来可能像这样:
# where you define the graph
epochCount = tf.get_variable( 'epochCount', initializer= 0)
update_epoch = tf.assign(epochCount, epochCount + 1)
...
# after you launched the session
for step in range(1, num_steps):
if step % 20000 == 0:
sess.run(update_epoch)