部署张量流服务模型时Tensorflow中的运行时错误

时间:2018-03-29 10:51:09

标签: python tensorflow

我写了一个张量流代码,这个代码指的是张量流服务模型,它可以从网络和网站访问。我正在使用Flask API进行Web访问,当我运行时,它说“尝试使用封闭会话”。有谁告诉我如何解决这个问题?“ 这是代码:

import tensorflow as tf
import numpy as np
import os, sys

DATA_SIZE = 100
SAVE_PATH = './save'
EPOCHS = 1000
LEARNING_RATE = 0.01
MODEL_NAME = 'test'

if not os.path.exists(SAVE_PATH):
    os.mkdir(SAVE_PATH)

data = (np.random.rand(DATA_SIZE, 2), np.random.rand(DATA_SIZE, 1))
test = (np.random.rand(DATA_SIZE // 8, 2), np.random.rand(DATA_SIZE // 8, 1))

tf.reset_default_graph()

x = tf.placeholder(tf.float32, shape=[None, 2], name='inputs')
y = tf.placeholder(tf.float32, shape=[None, 1], name='targets')

net = tf.layers.dense(x, 16, activation=tf.nn.relu)
net = tf.layers.dense(net, 16, activation=tf.nn.relu)
pred = tf.layers.dense(net, 1, activation=tf.nn.sigmoid, name='prediction')

loss = tf.reduce_mean(tf.squared_difference(y, pred), name='loss')
train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
should_train = checkpoint == None

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if should_train:
        print("Training")
        saver = tf.train.Saver()
        for epoch in range(EPOCHS):
            _, curr_loss = sess.run([train_step, loss], feed_dict={x: data[0], y: data[1]})
            print('EPOCH = {}, LOSS = {:0.4f}'.format(epoch, curr_loss))
        path = saver.save(sess, SAVE_PATH + '/' + MODEL_NAME + '.ckpt')
        print("saved at {}".format(path))
    else:
        print("Restoring")
        graph = tf.get_default_graph()
        saver = tf.train.import_meta_graph(checkpoint + '.meta')
        saver.restore(sess, checkpoint)

        loss = graph.get_tensor_by_name('loss:0')

        test_loss = sess.run(loss, feed_dict={'inputs:0': test[0], 'targets:0': test[1]})
        print(sess.run(pred, feed_dict={'inputs:0': np.random.rand(10,2)}))
        print("TEST LOSS = {:0.4f}".format(test_loss))

import tensorflow as tf
import os

SAVE_PATH = './save'
MODEL_NAME = 'test'
VERSION = 1
SERVE_PATH = './serve/{}/{}'.format(MODEL_NAME, VERSION)

checkpoint = tf.train.latest_checkpoint(SAVE_PATH)

tf.reset_default_graph()

with tf.Session() as sess:
    # import the saved graph
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    # get the graph for this session
    graph = tf.get_default_graph()
    sess.run(tf.global_variables_initializer())
    # get the tensors that we need
    inputs = graph.get_tensor_by_name('inputs:0')
    predictions = graph.get_tensor_by_name('prediction/Sigmoid:0')

# create tensors info
model_input = tf.saved_model.utils.build_tensor_info(inputs)
model_output = tf.saved_model.utils.build_tensor_info(predictions)

# build signature definition
signature_definition = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'inputs': model_input},
    outputs={'outputs': model_output},
    method_name= tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

builder = tf.saved_model.builder.SavedModelBuilder(SERVE_PATH)

builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature_definition
    })
# Save the model so we can serve it with a model server :)
builder.save()

运行此代码后,我收到以下错误。

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-0aab8e5ba747> in <module>()
      5     signature_def_map={
      6         tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
----> 7             signature_definition
      8     })
      9 # Save the model so we can serve it with a model server :)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\saved_model\builder_impl.py in add_meta_graph_and_variables(self, sess, tags, signature_def_map, assets_collection, legacy_init_op, clear_devices, main_op)
    377     # SavedModel can be copied or moved, this avoids the checkpoint state to
    378     # become outdated.
--> 379     saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
    380 
    381     # Export the meta graph def.

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py in save(self, sess, save_path, global_step, latest_filename, meta_graph_suffix, write_meta_graph, write_state)
   1571           model_checkpoint_path = sess.run(
   1572               self.saver_def.save_tensor_name,
-> 1573               {self.saver_def.filename_tensor_name: checkpoint_file})
   1574         else:
   1575           self._build_eager(

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    887     try:
    888       result = self._run(None, fetches, feed_dict, options_ptr,
--> 889                          run_metadata_ptr)
    890       if run_metadata:
    891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1045     # Check session.
   1046     if self._closed:
-> 1047       raise RuntimeError('Attempted to use a closed Session.')
   1048     if self.graph.version == 0:
   1049       raise RuntimeError('The Session graph is empty.  Add operations to the '

RuntimeError: Attempted to use a closed Session.

我无法追踪此错误。任何人都可以解决这个问题。事先提前。

0 个答案:

没有答案