I have written the following convolutional neural network (CNN) class in TensorFlow [I have tried to omit some lines of code for clarity.]
class CNN:
    def __init__(self,
                 num_filters=16,      # initial number of convolution filters
                 num_layers=5,        # number of convolution layers
                 num_input=2,         # number of channels in input
                 num_output=5,        # number of channels in output
                 learning_rate=1e-4,  # learning rate for the optimizer
                 display_step=5000,   # displays training results every display_step epochs
                 num_epoch=10000,     # number of epochs for training
                 batch_size=64,       # batch size for mini-batch processing
                 restore_file=None,   # restore file (default: None)
                 ):
        # define placeholders
        self.image = tf.placeholder(tf.float32, shape=(None, None, None, self.num_input))
        self.groundtruth = tf.placeholder(tf.float32, shape=(None, None, None, self.num_output))
        # builds CNN and compute prediction
        self.pred = self._build()
        # I have already created a tensorflow session and saver objects
        self.sess = tf.Session()
        self.saver = tf.train.Saver()
        # also, I have defined the loss function and optimizer as
        self.loss = self._loss_function()
        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.loss)
        if restore_file is not None:
            print("model exists...loading from the model")
            self.saver.restore(self.sess, restore_file)
        else:
            print("model does not exist...initializing")
            self.sess.run(tf.initialize_all_variables())

    def _build(self):
        # builds CNN
        ...

    def _loss_function(self):
        # computes loss
        ...

    def train(self, train_x, train_y, val_x, val_y):
        # uses mini batch to minimize the loss
        self.sess.run(self.optimizer, feed_dict={self.image: sample, self.groundtruth: gt})
        # I save the session after n=10 epochs as:
        if epoch % n == 0:
            self.saver.save(self.sess, 'snapshot', global_step=epoch)

    # finally my predict function is
    def predict(self, X):
        return self.sess.run(self.pred, feed_dict={self.image: X})
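I have trained two CNNs independently for two separate tasks. Each took around 1 day. Say model1 and model2 are saved as 'snapshot-model1-10000' and 'snapshot-model2-10000' (with their corresponding meta files). I can test each model and compute its performance separately.
Now I want to load these two models in a single script. Naturally, I tried the following: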
cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........)
cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........)
I get an error [The error message is long. I have just copied/pasted a snippet of it.]
NotFoundError: Tensor name "Variable_26/Adam_1" not found in checkpoint files /home/amitkrkc/codes/A549_models/snapshot-hela-95000
[[Node: save_1/restore_slice_85 = RestoreSlice[dt=DT_FLOAT, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_1/Const_0, save_1/restore_slice_85/tensor_name, save_1/restore_slice_85/shape_and_slice)]]
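Is there any way to load two separate CNNs from these two files? Any suggestions/comments/feedback are welcome.
Thanks,
Answer 0 (score: 16)
Yes. Use separate graphs.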
g1 = tf.Graph()
g2 = tf.Graph()
with g1.as_default():
    cnn1 = CNN(..., restore_file='snapshot-model1-10000',..........)
with g2.as_default():
    cnn2 = CNN(..., restore_file='snapshot-model2-10000',..........)
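EDIT:
If you want them in the same graph, you will have to rename some variables. One idea is to put each CNN in its own scope and have the saver handle only the variables in that scope, e.g.:
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model1'))
and wrap all of the construction inside the CNN in that scope:
with tf.variable_scope('model1'):
    ...
EDIT2:
Another idea is to rename the variables that the saver manages (since I assume you want to use your saved checkpoints without retraining everything). The saver allows different variable names in the graph and in the checkpoint; have a look at the documentation for its initialization.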
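A minimal sketch of this renaming idea, assuming the TensorFlow 1.x API used in the question and a checkpoint that was saved without any scope prefix. The helper names `checkpoint_name` and `make_scoped_saver` are hypothetical, and TensorFlow is imported lazily only so that the string helper can be tried without it. The dict passed to `tf.train.Saver` maps each name stored in the checkpoint to the corresponding variable in the current graph.

```python
def checkpoint_name(op_name, scope):
    """Map a scoped graph name back to the unscoped name stored in the checkpoint."""
    prefix = scope + "/"
    if not op_name.startswith(prefix):
        raise ValueError("%r is not inside scope %r" % (op_name, scope))
    return op_name[len(prefix):]


def make_scoped_saver(scope):
    """Build a Saver that restores unscoped checkpoint names into `scope`.

    Assumes TensorFlow 1.x and that the model was already constructed inside
    `with tf.variable_scope(scope):` in the current default graph.
    """
    import tensorflow as tf  # lazy import: only needed when a Saver is built

    scoped_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope + "/")
    # Saver accepts a dict of {name in checkpoint: variable in graph}.
    name_map = {checkpoint_name(v.op.name, scope): v for v in scoped_vars}
    return tf.train.Saver(var_list=name_map)


# The Adam slot from the error message, as it would be named inside scope 'model1'.
print(checkpoint_name("model1/Variable_26/Adam_1", "model1"))  # Variable_26/Adam_1
```

The returned saver would take the place of the plain `tf.train.Saver()` in the class before calling `restore`. Whether every name lines up still depends on how the checkpoint was originally written.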
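Answer 1 (score: 4)
This should be a comment on the most up-voted answer, but I do not have enough reputation to do that.
Anyway, if you (anyone who searched and ended up here) still cannot solve the problem with the solution provided by lpp, and you are using Keras, check the following quote from GitHub.
This is because Keras shares a global session if no default TF session is provided.
When model1 is created, it is on graph1. When model1 loads weights, the weights are on the Keras global session, which is associated with graph1.
When model2 is created, it is on graph2. When model2 loads weights, the global session does not know about graph2.
The following solution may help: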
from tensorflow import Graph, Session
from keras.models import model_from_json

graph1 = Graph()
with graph1.as_default():
    session1 = Session()
    with session1.as_default():
        with open('model1_arch.json') as arch_file:
            model1 = model_from_json(arch_file.read())
        model1.load_weights('model1_weights.h5')
        # K.get_session() is session1
# do the same for graph2, session2, model2
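Answer 2 (score: 0)
I ran into the same problem and could not solve it (without retraining) with any solution I found on the internet. So what I did was load each model in its own thread, with the two threads communicating with the main thread. The code is simple enough to write; you just have to be careful when you synchronize the threads. In my case each thread received the input for its problem and returned the output to the main thread. It works without any observable overhead.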
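One way to structure the thread-based approach described above, sketched with only the standard library. The `ModelWorker` class and the two stand-in predict functions are hypothetical; in a real script each thread would presumably build or restore its own model inside `run` before entering the request loop.

```python
import queue
import threading


class ModelWorker(threading.Thread):
    """Owns one model in its own thread and serves prediction requests."""

    def __init__(self, load_model):
        super().__init__(daemon=True)
        self._load_model = load_model  # callable that returns a predict function
        self._requests = queue.Queue()

    def run(self):
        # The model is created inside the worker thread, not the main thread.
        predict = self._load_model()
        while True:
            item = self._requests.get()
            if item is None:  # sentinel: shut down
                break
            x, reply = item
            reply.put(predict(x))

    def predict(self, x):
        """Called from the main thread; blocks until the worker answers."""
        reply = queue.Queue(maxsize=1)
        self._requests.put((x, reply))
        return reply.get()

    def stop(self):
        self._requests.put(None)
        self.join()


# Stand-ins for the two restored models (assumptions for illustration only).
worker1 = ModelWorker(lambda: (lambda x: x + 1))
worker2 = ModelWorker(lambda: (lambda x: x * 2))
worker1.start()
worker2.start()

out1 = worker1.predict(10)
out2 = worker2.predict(10)
print(out1, out2)

worker1.stop()
worker2.stop()
```

Each request carries its own reply queue, so the main thread never has to guess which worker an answer came from. Error handling is left out: if a predict function raised, the caller would block, so a real version would need to pass exceptions back as well.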
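Answer 3 (score: 0)
If you want to train or load multiple models in succession, one approach is to clear the session. You can easily do this using
from keras import backend as K
K.clear_session()
K.clear_session() destroys the current TF graph and creates a new one. This helps avoid clutter from old models/layers.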
Answer 4 (score: 0)
You need to create 2 sessions and restore the 2 models separately. For this to work, you need to do the following:
1a. When saving the models, you need to add a scope to the variable names. That way you will know which variables belong to which model:
# The first model
tf.Variable(tf.zeros([self.batch_size]), name="model_1/Weights")
...
# The second model
tf.Variable(tf.zeros([self.batch_size]), name="model_2/Weights")
...
1b. Alternatively, if you have already saved the models, you can rename the variables by adding a scope with this script.
2. When restoring the different models, you need to filter by variable name.
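The filtering in step 2 might look like the following sketch, assuming the TensorFlow 1.x API and checkpoints whose variable names already carry the `model_1/` and `model_2/` prefixes from step 1a. The helper names `in_scope` and `restore_scope`, and the `checkpoint_path` parameter, are hypothetical. TensorFlow is imported lazily only so that the name filter can be tried on its own.

```python
def in_scope(var_name, scope):
    """True if a variable name such as 'model_1/Weights:0' lies in `scope`."""
    return var_name.startswith(scope + "/")


def restore_scope(scope, checkpoint_path):
    """Create a session and restore only the variables under `scope`.

    Assumes TensorFlow 1.x and that both models were already built in the
    current default graph. `checkpoint_path` is a placeholder for the real path.
    """
    import tensorflow as tf  # lazy import: only needed when restoring

    selected = [v for v in tf.global_variables() if in_scope(v.name, scope)]
    saver = tf.train.Saver(var_list=selected)
    sess = tf.Session()
    saver.restore(sess, checkpoint_path)
    return sess


names = ["model_1/Weights:0", "model_2/Weights:0", "model_10/Weights:0"]
print([n for n in names if in_scope(n, "model_1")])  # ['model_1/Weights:0']
```

Matching on the scope plus a trailing slash, rather than on a bare substring, keeps a scope such as `model_10` from being picked up when `model_1` is requested.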