嗨,我有一个Python脚本,其中我实例化了神经网络类的两个对象。 每个对象定义自己的会话,并提供保存图形的方法。
import tensorflow as tf
import os, shutil
class TestNetwork:
def __init__(self, id):
self.id = id
tf.reset_default_graph()
self.s = tf.placeholder(tf.float32, [None, 2], name='s')
w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)
self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,
bias_initializer=b_initializer, name= 'k')
'''Defines self.session and initialize the variables'''
session_conf = tf.ConfigProto(
allow_soft_placement = True,
log_device_placement = False)
self.session = tf.Session(config = session_conf)
self.session.run(tf.global_variables_initializer())
def save_model(self, output_dir):
'''Save the network graph and weights to disk'''
if os.path.exists(output_dir):
# if provided output_dir already exists, remove it
shutil.rmtree(output_dir)
builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
builder.add_meta_graph_and_variables(
self.session,
[tf.saved_model.tag_constants.SERVING],
clear_devices=True)
# create a new directory output_dir and store the saved model in it
builder.save()
t1 = TestNetwork(1)
t2 = TestNetwork(2)
t1.save_model("t1_model")
t2.save_model("t2_model")
我得到的错误是
TypeError:无法将feed_dict键解释为张量:名称 'save / Const:0'表示不存在的张量。操作, 图中不存在“保存/常量”。
我读到一些东西说这个错误是由于tf.train.Saver
造成的。
因此,我在__init__
方法的末尾添加了以下行:
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5)
但是我仍然收到错误消息。
答案 0 :(得分:2)
tf.reset_default_graph
将清除默认图形堆栈并重置全局默认图形。
注意:默认图形是当前线程的属性。这个 函数仅适用于当前线程。调用此功能 当tf.Session或tf.InteractiveSession处于活动状态时将导致 未定义的行为。 使用任何以前创建的tf.Operation或 tf.Tensor对象在调用此函数后将导致未定义 行为。
您应该分别指定Graph
,并在相应的图形范围内定义所有这些内容。
def __init__(self, id):
self.id = id
self.graph = tf.Graph()
with self.graph.as_default():
self.s = tf.placeholder(tf.float32, [None, 2], name='s')
w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)
self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,
bias_initializer=b_initializer, name= 'k')
init = tf.global_variables_initializer()
'''Defines self.session and initialize the variables'''
session_conf = tf.ConfigProto(
allow_soft_placement = True,
log_device_placement = False)
self.session = tf.Session(config = session_conf,graph=self.graph)
self.session.run(init)
tf.train.Saver
是保存模型变量的另一种方法。
修改 如果您获得空的“变量”,则应将模型保存在图形中:
def save_model(self, output_dir):
'''Save the network graph and weights to disk'''
if os.path.exists(output_dir):
# if provided output_dir already exists, remove it
shutil.rmtree(output_dir)
with self.graph.as_default():
builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
builder.add_meta_graph_and_variables(
self.session,
[tf.saved_model.tag_constants.SERVING],
clear_devices=True)
# create a new directory output_dir and store the saved model in it
builder.save()