训练Keras模型时,使用稀疏数组表示标签

时间:2019-06-28 15:06:14

标签: python tensorflow keras sparse-matrix

我正在构建Keras模型以将数据分类为3000个不同的类别,我的训练数据包含大量样本,因此在对训练的输出进行一次热编码后,数据非常大(item_count * 3000 *浮动+输入数据大小) 有可能建议将稀疏数组作为训练数据的输出传递给keras吗?

1 个答案:

答案 0 :(得分:1)

您可以通过使用sparse_categorical_crossentropy损失函数来稀疏地表示您的基本事实。

# assuming get_model() returns your Keras model with an output_shape == [None, 3000]
# assuming get_data() returns training data, with y_train having shape == [num_samples]

x_train, y_train = get_data()
model = get_model()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=16)