Bayesian neural network gets low accuracy and does not learn

Date: 2021-03-30 11:09:38

Tags: tensorflow machine-learning bayesian-networks tensorflow-probability

I want to use TensorFlow Probability for a Bayesian neural network. A single-layer network on the MNIST dataset works. Now I want to build a BNN for the CIFAR-10 dataset. I found an AlexNet online and tried to convert it into a BNN, but it does not work. After tweaking it, it reaches up to 47% accuracy on the training set and then falls back to 10%, forgetting everything it learned. On the test set it is about 10% lower. (When running the example code here, it does not behave as badly as before, but 47% accuracy is still low, and this kind of instability makes it impractical.)

Original Bayesian AlexNet

Dropout does not seem to work. Using Conv2DReparameterization layers instead of Conv2DFlipout helps improve accuracy.

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import tensorflow_probability as tfp

def normalize_image(image, label):
    # Normalize images: `uint8` -> `float32` in [0, 1].
    image = tf.cast(image, tf.float32) / 255.0
    # Resize to AlexNet's expected 227x227 input.
    image = tf.image.resize(image, (227, 227))
    return image, label

ds_train, ds_test = tfds.load('cifar10', split=['train', 'test'], shuffle_files=True, as_supervised=True)
ds_train = ds_train.map(normalize_image, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.map(normalize_image, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache().shuffle(1000).batch(100).prefetch(tf.data.AUTOTUNE)
ds_test = ds_test.batch(100).cache().prefetch(tf.data.AUTOTUNE)
n_batches=len(ds_train)

kl_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / 100  # scale the KL term by the batch size (100)

model = keras.models.Sequential([
    tfp.layers.Convolution2DReparameterization(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3), kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    tfp.layers.Convolution2DReparameterization(filters=256, kernel_size=(5,5), strides=(1,1), activation='relu', padding="same", kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    tfp.layers.Convolution2DReparameterization(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same", kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    tfp.layers.Convolution2DReparameterization(filters=384, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same", kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    tfp.layers.Convolution2DReparameterization(filters=256, kernel_size=(1,1), strides=(1,1), activation='relu', padding="same", kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    keras.layers.Flatten(),
    tfp.layers.DenseFlipout(4096, activation='relu', kernel_divergence_fn = kl_fn),
    #keras.layers.Dropout(0.5),
    tfp.layers.DenseFlipout(4096, activation='relu', kernel_divergence_fn = kl_fn),
    #keras.layers.Dropout(0.5),
    tfp.layers.DenseFlipout(10, activation='softmax', kernel_divergence_fn = kl_fn)
])

optimizer = tf.optimizers.SGD(learning_rate=1e-4)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

model.fit(ds_train, epochs=100, validation_data=ds_test, validation_freq=1, shuffle = True, steps_per_epoch = n_batches)
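
I am not sure about the KL scaling: the loss values in the log below are on the order of 1.4 million, which suggests the KL term dwarfs the cross-entropy, and the TFP examples I have seen divide by the number of training examples rather than the batch size. A minimal sketch of that alternative (assuming CIFAR-10's 50,000 training images):

# Sketch: divide each layer's KL by the dataset size instead of the batch
# size, so the summed per-batch KL matches the per-example ELBO and the
# prior term cannot swamp the cross-entropy.
NUM_TRAIN_EXAMPLES = 50000
kl_fn = lambda q, p, _: tfp.distributions.kl_divergence(q, p) / NUM_TRAIN_EXAMPLES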

Training results:

Epoch 1/100
500/500 [==============================] - 71s 125ms/step - loss: 1426290.4304 - accuracy: 0.1615 - val_loss: 1425954.2500 - val_accuracy: 0.1854
Epoch 2/100
500/500 [==============================] - 57s 114ms/step - loss: 1425846.6692 - accuracy: 0.2494 - val_loss: 1425586.8750 - val_accuracy: 0.2185
Epoch 3/100
500/500 [==============================] - 57s 113ms/step - loss: 1425516.2665 - accuracy: 0.2598 - val_loss: 1425274.3750 - val_accuracy: 0.2343
Epoch 4/100
500/500 [==============================] - 57s 114ms/step - loss: 1425212.5324 - accuracy: 0.2742 - val_loss: 1424987.2500 - val_accuracy: 0.2320
Epoch 5/100
500/500 [==============================] - 57s 114ms/step - loss: 1424925.1495 - accuracy: 0.2765 - val_loss: 1424694.2500 - val_accuracy: 0.2383
Epoch 6/100
500/500 [==============================] - 56s 112ms/step - loss: 1424641.7445 - accuracy: 0.2766 - val_loss: 1424418.5000 - val_accuracy: 0.2346
Epoch 7/100
500/500 [==============================] - 57s 114ms/step - loss: 1424363.5382 - accuracy: 0.2854 - val_loss: 1424144.1250 - val_accuracy: 0.2342
...
Epoch 100/100
500/500 [==============================] - 56s 112ms/step - loss: 1399679.6794 - accuracy: 0.2209 - val_loss: 1399480.8750 - val_accuracy: 0.1943
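
The absolute loss values above are far larger than the cross-entropy can be (a few nats at most), which is consistent with the KL term dominating. A quick check of this (a sketch; it relies on the TFP layers registering their divergences as Keras losses):

# The kernel_divergence_fn contributions are collected in model.losses;
# if their sum is on the order of the reported loss (~1.4e6 here), the
# optimizer is mostly minimizing the prior term, not the data fit.
kl_total = sum(model.losses)
print('summed KL contribution:', float(kl_total))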

Tweaked "AlexNet" with better performance

Reducing the number of layers helps improve accuracy.

model = keras.models.Sequential([
    tfp.layers.Convolution2DReparameterization(filters=96, kernel_size=(11,11), strides=(4,4), activation='relu', input_shape=(227,227,3), kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    tfp.layers.Convolution2DReparameterization(filters=384, kernel_size=(3,3), strides=(1,1), activation='relu', padding="same", kernel_divergence_fn = kl_fn),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2)),
    keras.layers.Flatten(),
    tfp.layers.DenseFlipout(10, activation='softmax', kernel_divergence_fn = kl_fn)
])
optimizer = tf.optimizers.SGD(learning_rate=1e-3)  # compile and fit are unchanged from above

Training results:

Epoch 1/100
500/500 [==============================] - 24s 42ms/step - loss: 25728.3201 - accuracy: 0.2082 - val_loss: 25688.4062 - val_accuracy: 0.2169
Epoch 2/100
500/500 [==============================] - 20s 40ms/step - loss: 25663.4967 - accuracy: 0.2992 - val_loss: 25633.4746 - val_accuracy: 0.2390
...
Epoch 64/100
500/500 [==============================] - 21s 42ms/step - loss: 22666.4894 - accuracy: 0.4792 - val_loss: 22631.2578 - val_accuracy: 0.3829
...
Epoch 99/100
500/500 [==============================] - 20s 40ms/step - loss: 21024.2983 - accuracy: 0.4577 - val_loss: 20989.7246 - val_accuracy: 0.3448
Epoch 100/100
500/500 [==============================] - 20s 39ms/step - loss: 20977.7475 - accuracy: 0.4616 - val_loss: 20943.1758 - val_accuracy: 0.3453

With Adam as the optimizer, performance is usually worse: it reaches its peak accuracy quickly, but also collapses quickly. With SGD, learning is usually slower and the peak accuracy is up to 2% higher, but the accuracy still usually degrades afterwards. In a good configuration it can hold steady (or decay very slowly).
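
One thing I have not tried systematically is gradient clipping, which might damp the collapses; a minimal sketch (the clipnorm value is an arbitrary guess, not something I have tuned):

# Sketch: SGD with gradient clipping, so that a single noisy Monte Carlo
# weight sample cannot produce a destructive update.
optimizer = tf.optimizers.SGD(learning_rate=1e-3, clipnorm=1.0)
model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])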

I also noticed that during training an epoch often starts at a high accuracy and then drops a few percentage points, or starts below the previous epoch's accuracy and climbs back up to it. So within a single epoch the accuracy swings considerably in both directions.
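
Part of this swing may simply be sampling noise: every forward pass through the Flipout/Reparameterization layers draws fresh weights, so a single-pass accuracy is itself a random quantity. A sketch of a Monte Carlo evaluation that averages predicted probabilities over several passes (the function name and the sample count are my own choices):

import numpy as np

# Average the predicted class probabilities over several stochastic
# forward passes; each pass samples new weights, so the mean is a
# steadier accuracy estimate than any single pass.
def mc_accuracy(model, dataset, num_samples=10):
    correct, total = 0, 0
    for images, labels in dataset:
        probs = np.mean([model(images, training=False).numpy()
                         for _ in range(num_samples)], axis=0)
        correct += np.sum(np.argmax(probs, axis=-1) == labels.numpy())
        total += int(labels.shape[0])
    return correct / total

print('MC test accuracy:', mc_accuracy(model, ds_test))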

Does anyone know what is wrong with my code?

0 Answers:

No answers yet.