如何显示喀拉拉邦的每一层?

时间:2020-10-17 21:53:57

标签: python tensorflow keras deep-learning neural-network

我在喀拉拉邦有一个非常简单的模型。我定义了一个函数来获取vgg19网络,然后将其与平坦层和密集层连接。当我打印模型摘要时,它不会显示vgg19网络中的每一层。有什么方法可以显示而不更改vgg19的功能吗?任何建议表示赞赏。

import keras
from keras.layers import Input, Dense, Flatten
from keras.models import Model


input = Input(shape=(32,32,3), name="main_input")

def Model_vgg19(input_shape,input):
    vgg19_model = keras.applications.VGG19(
        input_shape=(32,32,3),
        weights='imagenet',
        include_top=False
    )(input)
    return vgg19_model

model = Model_vgg19((32,32,3),input)
model = Flatten()(model)
model = Dense(10, activation='relu', name='features_inc')(model)

model = Model(input, model)
model.summary()

结果就像

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
main_input (InputLayer)      (None, 32, 32, 3)         0         
_________________________________________________________________
vgg19 (Model)                (None, 1, 1, 512)         20024384  
_________________________________________________________________
flatten_1 (Flatten)          (None, 512)               0         
_________________________________________________________________
features_inc (Dense)         (None, 10)                5130      
=================================================================
Total params: 20,029,514
Trainable params: 20,029,514
Non-trainable params: 0
_________________________________________________________________

Process finished with exit code 0

2 个答案:

答案 0 :(得分:0)

您可以在模型类型上使用方法get_layer(name,index)。您可以在https://www.tensorflow.org/api_docs/python/tf/keras/Model#get_layer

中找到有关此信息的更多信息

对于您的代码,您可以使用以下代码:

model.get_layer(index=1).summary()

由此,您将获得VGG19模型的摘要(该模型的索引为1)。祝你好运!

答案 1 :(得分:0)

您可以尝试

vgg19_model = keras.applications.VGG19(
        input_shape=(32,32,3),
        weights='imagenet',
        include_top=False
    )
x=vgg19_model.output
x=Flatten()(x)
# use `output =` instead `output == `
output = Dense(10, activation='relu', name='features_inc')(x)
model=Model(inputs=vgg19_model.input, outputs=output)

请注意。 VGG模型具有参数池。如果设置pooling ='max',则输出层是最大池化层,您可以直接将其输入到密集层中,因此无需包括平坦层。