当前正在尝试使此repo有效。
我正在尝试将经过训练的模型保存在本地计算机中,以便以后使用。我读过tensorflow的doc,通过调用tf.save_model.save(object)
来保存模型似乎很直观。但是我不确定如何申请。
原始代码在这里:model.py 以下是我的更改:
import tensorflow as tf
class ICON(tf.Module): # make it a tensorflow modul
def __init__(self, config, embeddingMatrix, session=None):
def _build_inputs(self):
def _build_vars(self):
def _convolution(self, input_to_conv):
def _inference(self):
def batch_fit(self, queries, ownHistory, otherHistory, labels):
feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory,
self._labels: labels}
loss, _ = self._sess.run([self.loss_op, self.train_op], feed_dict=feed_dict)
return loss
def predict(self, queries, ownHistory, otherHistory, ):
feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory}
return self._sess.run(self.predict_op, feed_dict=feed_dict)
def save(self): # attempt to save the model
tf.saved_model.save(
self, './output/model')
上面的代码产生ValueError,如下所示:
ValueError: Tensor("ICON/CNN/embedding_matrix:0", shape=(16832, 300), dtype=float32_ref) must be from the same graph as Tensor("saver_filename:0", shape=(), dtype=string).
答案 0 :(得分:1)
我相信您可以为此使用tf.train.Saver类
def save(self): # attempt to save the model
saver = tf.train.Saver()
saver.save(self._sess, './output/model')
然后您可以通过这种方式还原模型
saver = tf.train.import_meta_graph('./output/model.meta')
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint('./output'))
您可能还会发现此tutorial有助于进一步了解这一点。
编辑:如果要使用SavedModel
def save(self):
inputs = {'input_queries': self._input_queries, 'own_histories': self._own_histories, 'other_histories': self._other_histories}
outputs = {'output': self.predict_op}
tf.saved_model.simple_save(self._sess, './output/model', inputs, outputs)
然后您可以使用SavedModel使用tf.contrib.predictor.from_saved_model加载并提供服务
from tensorflow.contrib.predictor import from_saved_model
predictor = from_saved_model('./output/model')
predictions = predictor({'input_queries': input_queries, 'own_histories': own_histories, 'other_histories': other_histories})