在具有小特征集的数据集上使用tf.keras创建的DNN的低精度

时间:2018-06-03 08:42:32

标签: tensorflow keras deep-learning

total train data record: 460000

total cross-validation data record: 89000

number of output class: 392

tensorflow 1.8.0 CPU installation

每个数据记录有26个特征,其中25个是数字,一个是分类,其中一个热编码为19个附加特征。最初,并非每个数据记录都存在所有特征值。我已经使用avg来填充缺少的浮点类型功能和最常用的缺少int类型功能的值。输出可以是标记为0到391的392个类别中的任何一个。

最后,所有功能都通过StandardScaler()

传递

这是我的模特:

output_class = 392
X_train, X_test, y_train, y_test = get_data()

# y_train and y_test contains int from 0-391    
# Make y_train and y_test categorical
y_train = tf.keras.utils.to_categorical(y_train, unique_dtc_count)
y_test = tf.keras.utils.to_categorical(y_test, unique_dtc_count)

# Convert to float type
y_train = y_train.astype(np.float32)
y_test = y_test.astype(np.float32)

# tf.enable_eager_execution()  # turned off to use rmsprop optimizer

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(400, activation=tf.nn.relu, input_shape= 
(44,)))
model.add(tf.keras.layers.Dense(40000, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(392, activation=tf.nn.softmax))

model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

import logging
logging.getLogger().setLevel(logging.INFO)
model.fit(X_train, y_train, epochs=3)

loss, acc = model.evaluate(X_test, y_test)
print('Accuracy', acc)

但是这个模型在训练和测试数据上的准确率只有28%。我应该在这里做些什么才能在训练和测试数据上获得良好的准确性?我应该更广泛,更深入吗?或者我应该考虑采取更多功能吗?

注意:数据集中共有400个独特的功能。但大多数功能仅在5到10个数据记录中随机出现。某些功能与其他数据记录无关。我根据数据记录中的领域知识和频率选择了26个特征。

任何建议表示赞赏。感谢。

编辑:我忘了在原帖中添加这个,@ Neb建议一个不太广泛的网络,我实际上尝试了这个。我的第一个模型是[44,400,400,392]层。它让我在训练和测试方面的准确率达到了30%左右。

1 个答案:

答案 0 :(得分:1)

你的模型太宽了。您在第一个隐藏层中有 400 个节点,在第二个层中有 40.000 ,总计400 * 44 + 40.000 * 400 + 392 * 400 = 16.174 .400 参数。但是,您只输入了44个功能!

正因为如此,您的网络能够检测输入中最小的,最难以察觉的变化,最后它将它们视为有价值的信息而不是噪声。我很确定如果你长时间离开你的网络训练(这里我只看到3个纪元),它最终会过度拟合你的训练集。

你有一些解决方案:

  1. 减少每个级别的节点数。您还可以尝试添加1或2个新图层。可能的结构可能是pythonnet

  2. 实施回归。你有多种方法可以做到这一点:

    • 限制网络参数的范围
    • 实施Dropout
    • 实现批量标准化(已知具有小的正则化效应)
  3. 使用Adam Optimizer而不是RMSprop

  4. 如果您的功能有些相关,您可以尝试使用CNN而不是完全连接的网络。
  5. 然后,为了改善概括,您可以:

    1. 探索数据集以查找异常值并将其删除。异常值是一种可能会混淆网络或不传达任何其他信息的示例。
    2. “随机”初始化您的参数,例如使用Xavier的初始化
    3. 最后,我会说:你真的需要392课吗?你能合并一些吗?