如何按摩Keras框架的输入?

时间:2016-06-21 03:52:38

标签: deep-learning keras lstm

我是keras的新手,尽管阅读了文档和keras中的examples文件夹,但我仍然在努力解决如何将所有内容融合在一起。

特别是,我想从一个简单的任务开始:我有一系列令牌,其中每个令牌只有一个标签。我有很多像这样的训练数据 - 实际上是无限的,因为我可以根据需要生成更多的(token, label)训练对。

我想构建一个网络来预测给予令牌的标签。令牌的数量必须始终与标签的数量相同(一个令牌=一个标签)。

我希望这是基于所有周围的令牌,比如在同一行或句子或窗口内 - 而不仅仅是在前面的令牌上。

我自己走了多远:

现在正在努力:

  1. 所有input_diminput_shape参数...因为每个句子都有不同的长度(不同数量的标记和标签),我应该为input_dim添加什么输入层?
  2. 如何告诉网络使用整个令牌句进行预测,而不仅仅是一个令牌?如何在给定一系列令牌的情况下预测整个标签序列,而不仅仅是基于先前令牌的标签?
  3. 将文本拆分成句子或窗口是否有意义?或者我可以将整个文本的向量作为单个序列传递?什么是“序列”?
  4. 什么是“时间片”和“时间步长”?文档一直在提及,我不知道这与我的问题有什么关系。什么是keras的“时间”?
  5. 基本上我无法将文档中的概念(如“时间”或“序列”)与我的问题联系起来。像Keras#40这样的问题并没有让我更聪明。

    非常感谢指向网络上的相关示例或代码示例。不寻找学术文章。

    谢谢!

1 个答案:

答案 0 :(得分:2)

  1. 如果您有不同长度的序列,您可以填充它们或使用有状态的RNN实现,其中在批次之间保存激活。前者是最简单和最常用的。

  2. 如果您想在使用RNN时使用未来信息,您希望使用双向模型,在该模型中,您可以连接两个相反方向的RNN。 RNN将使用所有先前信息的表示,例如,预测。

  3. 如果您有很长的句子,那么对随机子序列进行采样并对其进行训练可能会有所帮助。 Fx 100个字符。这也有助于过度拟合。

  4. 时间步长是您的代币。句子是一系列字符/标记。

  5. 我写了一个例子,说明我如何理解你的问题,但它没有经过测试,所以它可能无法运行。如果有可能,我建议使用单热编码,而不是使用整数来表示您的数据,然后使用binary_crossentropy代替mse

    from keras.models import Model
    from keras.layers import Input, LSTM, TimeDistributed
    from keras.preprocessing import sequence
    
    # Make sure all sequences are of same length
    X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
    
    # The input shape is your sequence length and your token embedding size (which is 1)
    inputs = Input(shape=(maxlen, 1))
    
    # Build a bidirectional RNN
    lstm_forward = LSTM(128)(inputs)
    lstm_backward = LSTM(128, go_backwards=True)(inputs)
    bidirectional_lstm = merge([lstm_forward, lstm_backward], mode='concat', concat_axis=2)
    
    # Output each timestep into a fully connected layer with linear 
    # output to map to an integer
    sequence_output = TimeDistributed(Dense(1, activation='linear'))(bidirectional_lstm)
    # Dense(n_classes, activation='sigmoid') if you want to classify
    
    model = Model(inputs, sequence_output)
    model.compile('adam', 'mse')
    model.fit(X_train, y_train)