我写了一个张量流代码,这个代码指的是张量流服务模型,它可以从网络和网站访问。我正在使用Flask API进行Web访问,当我运行时,它说“尝试使用封闭会话”。有谁告诉我如何解决这个问题?“ 这是代码:
import tensorflow as tf
import numpy as np
import os, sys
DATA_SIZE = 100
SAVE_PATH = './save'
EPOCHS = 1000
LEARNING_RATE = 0.01
MODEL_NAME = 'test'
if not os.path.exists(SAVE_PATH):
os.mkdir(SAVE_PATH)
data = (np.random.rand(DATA_SIZE, 2), np.random.rand(DATA_SIZE, 1))
test = (np.random.rand(DATA_SIZE // 8, 2), np.random.rand(DATA_SIZE // 8, 1))
tf.reset_default_graph()
x = tf.placeholder(tf.float32, shape=[None, 2], name='inputs')
y = tf.placeholder(tf.float32, shape=[None, 1], name='targets')
net = tf.layers.dense(x, 16, activation=tf.nn.relu)
net = tf.layers.dense(net, 16, activation=tf.nn.relu)
pred = tf.layers.dense(net, 1, activation=tf.nn.sigmoid, name='prediction')
loss = tf.reduce_mean(tf.squared_difference(y, pred), name='loss')
train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)
checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
should_train = checkpoint == None
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
if should_train:
print("Training")
saver = tf.train.Saver()
for epoch in range(EPOCHS):
_, curr_loss = sess.run([train_step, loss], feed_dict={x: data[0], y: data[1]})
print('EPOCH = {}, LOSS = {:0.4f}'.format(epoch, curr_loss))
path = saver.save(sess, SAVE_PATH + '/' + MODEL_NAME + '.ckpt')
print("saved at {}".format(path))
else:
print("Restoring")
graph = tf.get_default_graph()
saver = tf.train.import_meta_graph(checkpoint + '.meta')
saver.restore(sess, checkpoint)
loss = graph.get_tensor_by_name('loss:0')
test_loss = sess.run(loss, feed_dict={'inputs:0': test[0], 'targets:0': test[1]})
print(sess.run(pred, feed_dict={'inputs:0': np.random.rand(10,2)}))
print("TEST LOSS = {:0.4f}".format(test_loss))
import tensorflow as tf
import os
SAVE_PATH = './save'
MODEL_NAME = 'test'
VERSION = 1
SERVE_PATH = './serve/{}/{}'.format(MODEL_NAME, VERSION)
checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
tf.reset_default_graph()
with tf.Session() as sess:
# import the saved graph
saver = tf.train.import_meta_graph(checkpoint + '.meta')
# get the graph for this session
graph = tf.get_default_graph()
sess.run(tf.global_variables_initializer())
# get the tensors that we need
inputs = graph.get_tensor_by_name('inputs:0')
predictions = graph.get_tensor_by_name('prediction/Sigmoid:0')
# create tensors info
model_input = tf.saved_model.utils.build_tensor_info(inputs)
model_output = tf.saved_model.utils.build_tensor_info(predictions)
# build signature definition
signature_definition = tf.saved_model.signature_def_utils.build_signature_def(
inputs={'inputs': model_input},
outputs={'outputs': model_output},
method_name= tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder = tf.saved_model.builder.SavedModelBuilder(SERVE_PATH)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature_definition
})
# Save the model so we can serve it with a model server :)
builder.save()
运行此代码后,我收到以下错误。
RuntimeError Traceback (most recent call last)
<ipython-input-9-0aab8e5ba747> in <module>()
5 signature_def_map={
6 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
----> 7 signature_definition
8 })
9 # Save the model so we can serve it with a model server :)
~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\builder_impl.py in add_meta_graph_and_variables(self, sess, tags, signature_def_map, assets_collection, legacy_init_op, clear_devices, main_op)
377 # SavedModel can be copied or moved, this avoids the checkpoint state to
378 # become outdated.
--> 379 saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
380
381 # Export the meta graph def.
~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state)
1571 model_checkpoint_path = sess.run(
1572 self.saver_def.save_tensor_name,
-> 1573 {self.saver_def.filename_tensor_name: checkpoint_file})
1574 else:
1575 self._build_eager(
~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
887 try:
888 result = self._run(None, fetches, feed_dict, options_ptr,
--> 889 run_metadata_ptr)
890 if run_metadata:
891 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1045 # Check session.
1046 if self._closed:
-> 1047 raise RuntimeError('Attempted to use a closed Session.')
1048 if self.graph.version == 0:
1049 raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: Attempted to use a closed Session.
我无法追踪此错误。任何人都可以解决这个问题。事先提前。