Tensorflow saver.save()无法正常工作

时间:2019-02-12 11:00:01

标签: tensorflow deep-learning google-colaboratory

调用saver.save()可以在Google colab上保存四个文件,但是在我的本地计算机上,我只是得到一个检查点文件。 (使用macOS High Sierra)

def train(model,model_dir,batch_szie,tr_x,tr_y,val_x,val_y,lr_rate):
saver = tf.train.Saver()
epoch = 0
max_epoch = 300
print ("starting training")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  while epoch < max_epoch:
    for br_x,br_y in shuffle_batch(tr_x,tr_y,batch_size):
      sess.run(model.opt,feed_dict={model.inputs:br_x,model.outputs:br_y,model.is_training:True,model.lr_rate:lr_rate})  

    epoch = epoch + 1
  saver.save(sess,model_dir,write_meta_graph=True)
  print ('completed training, model saved')

这就是我所谓的训练功能。 custom_dnn()函数创建一个自定义类的对象,然后调用train函数。该模型已经正确训练,我已经可视化了训练图,但是保存后,只有检查点文件保存在本地。

model_dir = "/Users/proj/model/"
batch_size = 50
lr_rate = 0.001
tf.reset_default_graph()
model_sub1 = custom_dnn()
train(model_sub1,model_dir,batch_size,tr_x,tr_y,val_x,val_y,lr_rate)

我在做什么错了?

0 个答案:

没有答案