我用以下代码将一个 DNN 模型拆分为两部分：
def split(model, input, end_layer=None):
    """Build a sub-model of ``model`` that starts at the layer named ``input``.

    A fresh ``Input`` with the starting layer's (node 0) input shape is
    created, and the layers of ``model`` are re-applied to it by walking the
    graph backwards from the end layer. The returned model reuses the
    original layer objects, so it shares their weights with ``model``.

    Args:
        model: A built Keras functional model.
        input: Name of the layer the sub-model starts at. (This parameter
            shadows the ``input`` builtin; the name is kept so existing
            keyword callers keep working.)
        end_layer: Optional last layer of the sub-model: a layer name, an
            index into ``model.layers``, or a layer object. When omitted, the
            original hard-coded choice is used: ``model.layers[-21]`` when
            starting at ``'input_1'`` and ``model.layers[-1]`` otherwise
            (previously any start other than ``'input_1'`` or
            ``'block1_pool'`` silently returned ``None``).

    Returns:
        A ``models.Model`` mapping the new input to the end layer's output.

    Raises:
        ValueError: if a path to the end layer reaches a layer with no inbound
            layers (a model input) without passing through the starting
            layer, i.e. the chosen start does not cut the graph cleanly.
    """
    starting_layer_name = input
    start_layer = model.get_layer(starting_layer_name)
    new_input = layers.Input(batch_shape=start_layer.get_input_shape_at(0))

    if end_layer is None:
        end_layer = (model.layers[-21] if starting_layer_name == 'input_1'
                     else model.layers[-1])
    elif isinstance(end_layer, str):
        end_layer = model.get_layer(end_layer)
    elif isinstance(end_layer, int):
        end_layer = model.layers[end_layer]

    # Memo: layer name -> output tensor in the new graph, so a layer feeding
    # several consumers is applied only once.
    layer_outputs = {}

    def get_output_of_layer(layer):
        """Return ``layer``'s output in the new graph, building it on demand."""
        if layer.name in layer_outputs:
            return layer_outputs[layer.name]
        if layer.name == starting_layer_name:
            if isinstance(layer, layers.InputLayer):
                # Use the fresh tensor directly instead of calling the old
                # InputLayer on it. Calling it may tie the new input back to
                # the original graph -- a likely cause of the
                # "Graph disconnected" error when the saved model is loaded.
                out = new_input
            else:
                out = layer(new_input)
            layer_outputs[layer.name] = out
            return out
        # Follow only node 0, i.e. the connectivity of the original model.
        # Every re-application of a layer (e.g. by an earlier call to split)
        # appends another inbound node that must not be traversed.
        inbound = layer._inbound_nodes[0].inbound_layers
        if not isinstance(inbound, (list, tuple)):
            # Some Keras versions return a bare layer instead of a list.
            inbound = [inbound]
        if not inbound:
            raise ValueError(
                "Layer {!r} is reachable from the end layer without passing "
                "through {!r}; the model cannot be split at that layer."
                .format(layer.name, starting_layer_name)
            )
        pl_outs = [get_output_of_layer(pl) for pl in inbound]
        out = layer(pl_outs[0] if len(pl_outs) == 1 else pl_outs)
        layer_outputs[layer.name] = out
        return out

    new_output = get_output_of_layer(end_layer)
    return models.Model(new_input, new_output)
# Front part: from the model input 'input_1' up to model.layers[-21].
block_1 = split(model, 'input_1')
# Back part: from 'block1_pool' through the last layer of the model.
block_2 = split(model, 'block1_pool')

# Persist each part to its own HDF5 file.
block_1.save('my_model1.h5')
block_2.save('my_model2.h5')
当我尝试运行以下代码加载模型时，出现错误：“Graph disconnected: cannot obtain value for tensor …”（图断开连接，无法获取张量的值）。
# Module names are case-sensitive: the package is `keras`, not `Keras`
# (`from Keras.models ...` fails with ModuleNotFoundError).
from keras.models import load_model

# Reload the first sub-model saved above and print its layer summary.
model = load_model('my_model1.h5')
model.summary()
非常感谢您帮助解决这个问题。我目前拆分模型的方法会导致上述错误；在 Keras 中是否有其他方法可以拆分模型？
答案 0（得分：0）
我修改了拆分模型的代码，现在 load_model 可以正常加载了。更新后的代码如下：
def split_keras_model(model, index):
    """Split a single-input, chain-topology Keras model into two models.

    The layers before ``index`` (excluding a leading ``InputLayer``, if any)
    form the first model and ``model.layers[index:]`` form the second. Both
    models re-apply the *original* layer objects, so they share weights with
    ``model``. Each layer is called only on the previous layer's output, so
    models with branches or skip connections are not supported.

    Args:
        model: A built Keras model with exactly one input.
        index: Position in ``model.layers`` of the first layer of the second
            model. Negative values count from the end, as in slicing.

    Returns:
        ``(model1, model2)``.

    Raises:
        IndexError: if ``index`` does not name a layer that can start the
            second model.
    """
    all_layers = model.layers
    # Functional models list their InputLayer first; Sequential models do
    # not. Only skip layers[0] when it really is an InputLayer, otherwise the
    # first real layer would be silently dropped. (Checked by class name to
    # avoid needing an extra import.)
    first = 1 if all_layers and type(all_layers[0]).__name__ == 'InputLayer' else 0
    if index < 0:
        index += len(all_layers)
    if not first <= index < len(all_layers):
        raise IndexError(
            "index must be in [{}, {}), got {}".format(first, len(all_layers), index)
        )

    # First part: a new input with the model's input shape minus the batch
    # dimension. model.input_shape is used because InputLayer.input_shape may
    # be a list of shape tuples in some tf.keras versions.
    layer_input_1 = Input(model.input_shape[1:])
    x = layer_input_1
    for layer in all_layers[first:index]:
        x = layer(x)
    model1 = Model(inputs=layer_input_1, outputs=x)

    # Second part: a new input tensor shaped like the input the split layer
    # received in the original model (node 0), so it can be fed directly.
    input_shape_2 = all_layers[index].get_input_shape_at(0)[1:]
    print("Input shape of model 2: " + str(input_shape_2))
    layer_input_2 = Input(shape=input_shape_2)
    x = layer_input_2
    for layer in all_layers[index:]:
        x = layer(x)
    model2 = Model(inputs=layer_input_2, outputs=x)

    return (model1, model2)
# Split so that model.layers[2:] form the second model; only the first
# part is saved and then reloaded here.
model1, model2 = split_keras_model(model, index=2)
model1.save('my_model1.h5')

# compile=False: do not compile the model after loading (only the
# architecture and weights are needed).
loaded_model = load_model('my_model1.h5', compile=False)