我正在尝试将两个字符级别的句子输入LSTM图层进行分类。我的示例与以下内容类似,我的标签是一个热门编码类。
标签:
label array([1., 0., 0., 0., 0.])
示例:
array([['0', ' '],
[' ', 'l'],
['1', 'e'],
['1', 't'],
['2', 'n'],
['8', 'i'],
[' ', ' '],
['"', ';'],
['h', 'h'],
['t', 's'],
['t', 'o'],
['p', 't'],
['s', 'n'],
[':', 'i'],
['/', 'c'],
['/', 'a'],
['w', 'm'],
['w', '('],
['w', ' '],
['.', '0'],
['e', '.'],
['x', '5'],
['a', '/'],
['m', 'a'],
['p', 'l'],
['l', 'l'],
['e', 'i'],
['.', 'z'],
['c', 'o'],
['o', 'm'],
['m', '"'],
['/', ' '],
['c', '"'],
['m', '/'],
['s', 'd'],
['/', 'a'],
['t', 'o'],
['i', 'l'],
['n', 'n'],
['a', 'w'],
['-', 'o'],
['a', 'd'],
['c', '-'],
['c', 'r'],
['e', 'o'],
['s', 'f'],
['s', '-'],
['-', 'r'],
['e', 'o'],
['d', 't'],
['i', 'i']], dtype='<U1')
我正在尝试使用Keras的嵌入层将字符映射到矢量中。然而,嵌入层仅采用单维序列。如何调整网络以采用多维序列?目前我有以下代码适用于单维样本。 51是我的窗口大小,74是我词汇量的大小。
model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=74,
output_dim=74,
input_length=51))
model.add(keras.layers.Dropout(0.2))
model.add(keras.layers.LSTM(64,
dropout=0.5,
recurrent_dropout=0.5,
return_sequences=True,
input_shape=(51, 74)))
model.add(keras.layers.LSTM(64,
dropout=0.5,
recurrent_dropout=0.5))
model.add(keras.layers.Dense(num_classes, activation='sigmoid'))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])
答案 0 :(得分:0)
好的,我通过在嵌入之前添加重塑图层,然后在嵌入之后添加另一个重塑图层来解决此问题。这是代码:
model = keras.models.Sequential()
model.add(keras.layers.Reshape((2 * lstm_window_size, 1), input_shape=(
lstm_window_size, 2)))
model.add(keras.layers.Embedding(input_dim=vocab_size + 1,
output_dim=100,
input_length=lstm_window_size * 2))
model.add(keras.layers.Reshape((lstm_window_size, 200)))
model.add(keras.layers.Dropout(0.2))
model.add(keras.layers.LSTM(64,
dropout=0.5,
recurrent_dropout=0.5,
return_sequences=True,
input_shape=(lstm_window_size, 2)))
model.add(keras.layers.LSTM(64,
dropout=0.5,
recurrent_dropout=0.5))
model.add(keras.layers.Dense(num_classes, activation='sigmoid'))
model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])