Restoring a saved neural network in TensorFlow

Posted: 2019-01-02 11:52:44

Tags: python tensorflow

Before marking my question as a duplicate, I want you to know that I have gone through many related questions, but none of the solutions there could clear my doubts or solve my problem. I have a trained neural network that I want to save, so that I can later test the model against a test dataset.

I tried to save and restore it, but I am not getting the expected results. Restoring does not seem to work; most probably I am using it incorrectly, and the model simply keeps the values given by the global variables initializer.

This is the code I am using to save the model:

sess.run(tf.initializers.global_variables())
#num_epochs = 7
for epoch in range(num_epochs):  
  start_time = time.time()
  train_accuracy = 0
  train_loss = 0
  val_loss = 0
  val_accuracy = 0

  for bid in range(int(train_data_size/batch_size)):
     X_train_batch = X_train[bid*batch_size:(bid+1)*batch_size]
     y_train_batch = y_train[bid*batch_size:(bid+1)*batch_size]
     sess.run(optimizer, feed_dict = {x:X_train_batch, y:y_train_batch,prob:0.50})  

     train_accuracy = train_accuracy + sess.run(model_accuracy, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})
     train_loss = train_loss + sess.run(loss_value, feed_dict={x : X_train_batch,y:y_train_batch,prob:0.50})

  for bid in range(int(val_data_size/batch_size)):
     X_val_batch = X_val[bid*batch_size:(bid+1)*batch_size]
     y_val_batch = y_val[bid*batch_size:(bid+1)*batch_size]
     val_accuracy = val_accuracy + sess.run(model_accuracy,feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})
     val_loss = val_loss + sess.run(loss_value, feed_dict = {x:X_val_batch, y:y_val_batch,prob:0.75})

  train_accuracy = train_accuracy/int(train_data_size/batch_size)
  val_accuracy = val_accuracy/int(val_data_size/batch_size)
  train_loss = train_loss/int(train_data_size/batch_size)
  val_loss = val_loss/int(val_data_size/batch_size)


  end_time = time.time()


  saver.save(sess,'./blood_model_x_v2',global_step = epoch)  

After saving the model, the following files are written to my working directory:

blood_model_x_v2-2.data-00000-of-00001
blood_model_x_v2-2.index
blood_model_x_v2-2.meta

Similar files are written for v2-3 and so on up to v2-6, along with a "checkpoint" file. I then tried to restore the model using the snippet below (after initialization), but the results differ from what I expect. What am I doing wrong?

saver = tf.train.import_meta_graph('blood_model_x_v2-5.meta')
saver.restore(test_session, tf.train.latest_checkpoint('./'))

1 Answer:

Answer 0 (score: 2):

According to the tensorflow docs:

   restore

   Restores previously saved variables.

   This method runs the ops added by the constructor for restoring variables. It requires a session in which the graph was launched. The variables to restore do not have to have been initialized, as restoring is itself a way of initializing variables.

Let's take an example.

We save the model similar to this:

import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24.0, which is (w1 + w2) * b1 = (4 + 8) * 2

# Now, save the graph
saver.save(sess, './ckpnt/my_test_model', global_step=1000)

As you can see, we do not initialize the session in the restoring part. There is a better way to save and restore a model with Checkpoint, which also allows you to check whether the model was restored correctly or not.