使用Keras的神经网络

时间:2020-01-04 22:57:37

标签: python keras neural-network

有人可以帮我使用keras.Sequential()翻译这个简单的神经网络吗? enter image description here

我基本上想知道如何定义一个神经网络,该神经网络为下一层的每个节点(而不是第一层的每个节点连接到第二层的每个节点)提供3个单独的输入节点。 我也不知道训练数据的数组应该如何成形。

1 个答案:

答案 0 :(得分:2)

基于https://keras.io/models/modelhttps://keras.io/layers/merge/

from keras.models import Model
from keras.layers import Input, Dense, Concatenate

a0 = Input(shape=(3,))
a1 = Input(shape=(3,))
a2 = Input(shape=(3,))
a3 = Input(shape=(3,))

b0 = Dense(1)(a0)
b1 = Dense(1)(a1)
b2 = Dense(1)(a2)
b3 = Dense(1)(a3)

b_concat = Concatenate(axis=-1)([b0, b1, b2, b3])

c = Dense(1)(b_concat)

model = Model(inputs=[a0, a1, a2, a3], outputs=[c])

model.compile(loss='mean_squared_error', optimizer='sgd')

model.summary()

礼物:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 3)]          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 3)]          0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 3)]          0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 3)]          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 1)            4           input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            4           input_2[0][0]                    
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 1)            4           input_3[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            4           input_4[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 4)            0           dense[0][0]                      
                                                                 dense_1[0][0]                    
                                                                 dense_2[0][0]                    
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1)            5           concatenate[0][0]                
==================================================================================================
Total params: 21
Trainable params: 21
Non-trainable params: 0

但是这种模型(此处没有激活功能)非常简单,也许“经典”机器学习方法可能更易于实现(请参见https://scikit-learn.org/stable/supervised_learning.html#supervised-learning)。