在python中创建Keras层的树结构

时间:2018-10-24 09:57:03

标签: keras

我正在尝试获取Keras层的结构。在一个简单的网络中,仅通过遍历model.layers就可以做到这一点。但是,如果网络很复杂(例如,连接不同的层),则这种方法是不够的。 这是一个示例:

FEATURES = ['A','B','C','D']
IMPORTANT_FEATURES = [0, 3]
NORMAL_FEATURES = [1, 2]
inputLayer = [Input(shape=(1, )) for i,f in enumerate(FEATURES)]
importantInput = keras.layers.Concatenate(axis=-1)([inputLayer[i] for i in IMPORTANT_FEATURES])
layer1 = Dense()(importantInput)
normalInput  = keras.layers.Concatenate(axis=-1)(layer1 + [inputLayer[i] for i in NORMAL_FEATURES])
layer2 = Dense()(normalInput)
model = Model([
    inputLayer[i]
    for i in range(len(FEATURES))
], layer2)

这会在model.layer中生成由以下组成的 Keras 层的列表:

[Input1, Input2, Concatenate1, Dense1, Input3, Input4, Concatenate2, Dense2]

获取级联输入的唯一方法是访问Concatenate1.input。但是,这会返回 Tensorflow 图层的列表。 是否可以仅使用 Keras 图层来获得图层的树结构?

1 个答案:

答案 0 :(得分:0)

输入层的 Keras 列表可通过éèàâê

获得