
Time: 2019-07-10 17:38:23

Tags: python tensorflow keras

I am trying to rewrite a TensorFlow network using Keras. The model in TensorFlow is defined as

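# Imports assumed for the standalone Keras API; the original post does not show them.
from keras import backend as K
from keras.layers import Input, Dense, LeakyReLU, Reshape, Lambda, Flatten
from keras.models import Model

def keras_version():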
    input = Input(shape=(135,), name='feature_input')
    out1 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(45, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Reshape((9, 5))(out1)

    out2 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((9, 4, 15))(out2)
    out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
    out2 = Flatten()(out2)
    out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
    # K.dot should be of size (-1, 4, 9, 9), so I set the output size to 324, and later on, reshaped data
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((4, 9, 9))(out2)
    out2 = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)))(out2)

    # `identity` is not defined in the post; presumably a pass-through such as `lambda x: x`.
    out1 = Lambda(identity, name='output_1')(out1)
    out2 = Lambda(identity, name='output_2')(out2)

    return Model(input, [out1, out2])


Layer (type)                    Output Shape         Param #     Connected to                     
feature_input (InputLayer)      (None, 135)          0                                            
dense_5 (Dense)                 (None, 128)          17408       feature_input[0][0]              
leaky_re_lu_5 (LeakyReLU)       (None, 128)          0           dense_5[0][0]                    
dense_6 (Dense)                 (None, 256)          33024       leaky_re_lu_5[0][0]              
leaky_re_lu_6 (LeakyReLU)       (None, 256)          0           dense_6[0][0]                    
dense_7 (Dense)                 (None, 512)          131584      leaky_re_lu_6[0][0]              
leaky_re_lu_7 (LeakyReLU)       (None, 512)          0           dense_7[0][0]                    
dense_1 (Dense)                 (None, 128)          17408       feature_input[0][0]              
dense_8 (Dense)                 (None, 540)          277020      leaky_re_lu_7[0][0]              
leaky_re_lu_1 (LeakyReLU)       (None, 128)          0           dense_1[0][0]                    
leaky_re_lu_8 (LeakyReLU)       (None, 540)          0           dense_8[0][0]                    
dense_2 (Dense)                 (None, 256)          33024       leaky_re_lu_1[0][0]              
reshape_2 (Reshape)             (None, 9, 4, 15)     0           leaky_re_lu_8[0][0]              
leaky_re_lu_2 (LeakyReLU)       (None, 256)          0           dense_2[0][0]                    
lambda_1 (Lambda)               (None, 4, 9, 9)      0           reshape_2[0][0]                  
dense_3 (Dense)                 (None, 512)          131584      leaky_re_lu_2[0][0]              
flatten_1 (Flatten)             (None, 324)          0           lambda_1[0][0]                   
leaky_re_lu_3 (LeakyReLU)       (None, 512)          0           dense_3[0][0]                    
dense_9 (Dense)                 (None, 324)          105300      flatten_1[0][0]                  
dense_4 (Dense)                 (None, 45)           23085       leaky_re_lu_3[0][0]              
leaky_re_lu_9 (LeakyReLU)       (None, 324)          0           dense_9[0][0]                    
leaky_re_lu_4 (LeakyReLU)       (None, 45)           0           dense_4[0][0]                    
reshape_3 (Reshape)             (None, 4, 9, 9)      0           leaky_re_lu_9[0][0]              
reshape_1 (Reshape)             (None, 9, 5)         0           leaky_re_lu_4[0][0]              
lambda_2 (Lambda)               (None, 9, 9, 4)      0           reshape_3[0][0]                  
output_1 (Lambda)               (None, 9, 5)         0           reshape_1[0][0]                  
output_2 (Lambda)               (None, 9, 9, 4)      0           lambda_2[0][0]                   
Total params: 769,437
Trainable params: 769,437
Non-trainable params: 0
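
As a sanity check on the layer sizes, the parameter counts in the summary above can be reproduced by hand. The sketch below assumes each Dense layer holds `inputs * units` weights plus `units` biases, with the sizes read off the model definition; the helper name `dense_params` is only illustrative.

```python
# (inputs, units) for every Dense layer, read off the model definition.
branch_1 = [(135, 128), (128, 256), (256, 512), (512, 45)]
branch_2 = [(135, 128), (128, 256), (256, 512), (512, 540), (324, 324)]


def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


per_layer = [dense_params(n_in, n_out) for n_in, n_out in branch_1 + branch_2]
total = sum(per_layer)

print(per_layer)
print(total)  # 769437, matching "Total params" in the summary
```

The LeakyReLU, Reshape, Flatten and Lambda layers contribute no parameters, which is why the Dense layers alone account for the total.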


  1. The way the layer sizes are defined.
  2. The way the weights are initialized.
  3. The way the matrix multiplication is flattened and reshaped.
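
For the third point (flattening the matrix multiplication and reshaping it), it may help to trace the shapes in plain Python. This is only a sketch of the shape bookkeeping, not of the Keras internals. It assumes that `K.dot` on two N-D tensors follows NumPy-style `dot` semantics (contract the last axis of the first operand with the second-to-last axis of the second, and keep all remaining axes of both), which is how the Keras backend documentation appears to describe it and is worth verifying against the installed version. The batch size of 32 and all helper names are hypothetical.

```python
def permute(shape, order):
    """Shape after reordering axes, like K.permute_dimensions."""
    return tuple(shape[i] for i in order)


def batched_matmul_shape(a, b):
    """Matmul that treats every leading axis as a shared batch axis."""
    assert a[:-2] == b[:-2] and a[-1] == b[-2]
    return a[:-1] + (b[-1],)


def nd_dot_shape(a, b):
    """NumPy-style N-D dot: only the contracted axes disappear."""
    assert a[-1] == b[-2]
    return a[:-1] + b[:-2] + (b[-1],)


batch = 32                         # hypothetical batch size
x = (batch, 9, 4, 15)              # output of Reshape((9, 4, 15))
left = permute(x, (0, 2, 1, 3))    # (32, 4, 9, 15)
right = permute(x, (0, 2, 3, 1))   # (32, 4, 15, 9)

print(batched_matmul_shape(left, right))  # (32, 4, 9, 9)
print(nd_dot_shape(left, right))          # (32, 4, 9, 32, 4, 9)

per_sample = 4 * 9 * 9
print(per_sample)                  # 324, the size expected by Flatten and Dense(324)
```

If that assumption holds, the Lambda layer would not yield the intended `(batch, 4, 9, 9)` tensor even though `output_shape=(4,9,9)` is declared, and a batched product such as `tf.matmul` on the two permuted tensors might be closer to the intent.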




0 answers:
