Keras中可变长度序列的Softmax

时间:2018-05-28 22:33:06

标签: tensorflow keras

我有2D形状(say [a, 10])的样本。从样品到样品的变化。我正在使用batch size = 1进行培训,以避免批量变量的问题。我创建了以下LSTM网络。现在的问题是我的目标是形状[1,a,1]的概率向量。每个样本的概率向量之和为1。

我想在最后一层上应用softmax激活,以便我可以将其与目标进行比较。我该怎么办?

    Layer (type)                            Output Shape                        Param #       
==========================================================================================
lstm_21 (LSTM)                          (1, None, 32)                       7808          
__________________________________________________________________________________________
lstm_22 (LSTM)                          (1, None, 8)                        1312          
__________________________________________________________________________________________
time_distributed_6 (TimeDistributed)    (1, None, 1)                        9             
==========================================================================================

这是我的代码

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LSTM(32, return_sequences=True, batch_input_shape=(1, None, len(features))))
model.add(tf.keras.layers.LSTM(8, return_sequences=True))
model.add(tf.keras.layers.Dense(1, activation='softmax'))
# model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1, activation='softmax')))

print(model.summary(90))

model.compile(loss = 'mean_squared_error',
              optimizer = 'adam')


def generate_arrays_from_pd(df, arr_df):
    while True:
        for i in range(arr_df.shape[0]):
            a1 = arr_df[i, 0]
            a2 = arr_df[i, 1]
            batch_x = df.loc[a1:a2, features].as_matrix().reshape((1, -1, len(features)))
            batch_y = df.loc[a1:a2, "mkt_shr"].as_matrix().reshape((1, -1, 1))
            yield(batch_x, batch_y)

model.fit_generator(generate_arrays_from_pd(dat_train, arr_train), steps_per_epoch=arr_train.shape[0], epochs = 10, verbose=1, shuffle=False)

1 个答案:

答案 0 :(得分:0)

您可以添加全局最大池化层,然后添加具有softmax激活的密集层。

全局最大池化层采用步长维上的最大向量,因此不会有更多不同形状的数据,然后您可以应用softmax激活的密集层。