Keras自定义激活功能(不训练)

时间:2018-08-31 19:18:11

标签: python tensorflow keras deep-learning

我正在尝试使用Keras后端或Tensorflow中的函数来实现自己的激活功能,但是我很难使该函数正确学习。

我的第一种方法是重建现有的激活功能(ELU),以查看我自己的激活功能是否存在问题,但是即使重建功能也无法像激活功能内置到Keras或Tensorflow中一样进行训练。

Tensorflow函数:

console.log()

Keras函数:

def custom_activation(x):
    cond = tf.greater(x, tf.constant(0.0))
    return tf.where(cond,
                    x,
                    tf.subtract(tf.exp(x), tf.constant(1.0)))

我正在使用mnist数据集和一个简单的8层完全连接的网络,每层中有128个节点来测试我的激活功能。该网络使用内置的ELU功能可以略微学习,但是使用自定义的Keras或Tensorflow函数,损耗立即接近零,并且精度完全没有提高。

我想念什么?

对于Keras函数,我遵循How do you create a custom activation function with Keras?;对于Tensorflow,我遵循this post


完整代码(用于复制/粘贴):

Keras的ELU (正常工作)

def custom_activation(x):
    cond = K.greater(x, 0)
    return K.switch(cond, x, K.exp(x) - 1)

get_custom_objects().update({'custom_activation': Activation(custom_activation)})

Keras中的自定义ELU

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], 28*28)
x_test = x_test.reshape(x_test.shape[0], 28*28)

from keras.utils import to_categorical

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

from keras import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD

model = Sequential([
    Dense(128, input_shape=x_train.shape[1:]),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(128),
    Activation('elu'),
    Dense(10),
    Activation('sigmoid')
])

model.compile(SGD(lr=0.01), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(x=x_train, y=y_train,
          validation_data=[x_test, y_test],
          batch_size=64, epochs=5)
使用Keras API在Tensorflow中

自定义ELU

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(x_train.shape[0], 28*28)
x_test = x_test.reshape(x_test.shape[0], 28*28)

from keras.utils import to_categorical

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

from keras import Sequential
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects

def custom_activation(x):
    cond = K.greater(x, 0)
    return K.switch(cond, x, K.exp(x) - 1)

get_custom_objects().update({'custom_activation': Activation(custom_activation)})

model = Sequential([
    Dense(128, input_shape=x_train.shape[1:]),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(128),
    Activation(custom_activation),
    Dense(10),
    Activation('sigmoid')
])

model.compile(SGD(lr=0.01), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(x=x_train, y=y_train,
          validation_data=[x_test, y_test],
          batch_size=64, epochs=5)

1 个答案:

答案 0 :(得分:2)

如果在model.get_weights()的情况下打印出custom_activation,则应该看到权重全部为nan。这就是为什么准确性没有改善的原因。

原因是K.exp(x)inf左右变为x > 88(而MNIST数据集包含从0到255的值)。结果,在通过0 * inf = nan进行梯度传播期间将遇到K.switch()计算。也许检查一下related TF issue以获得更多详细信息。似乎K.switch()(或等效的tf.where())不够聪明,无法弄清楚只有在自定义激活中只有K.exp(x)才需要x < 0的事实。

我不是TensorFlow方面的专家,但是我猜想内置ELU激活(称为tf.nn.elu)可以正常工作的原因是因为tf.nn.elu有自己的渐变操作。 x >= 0x < 0的分支是handled inside the gradient op,而不是将tf.exp()tf.where() ops的梯度相乘,因此可以避免0 * inf = nan计算

要解决该问题,您可以在训练之前对数据进行标准化,

x_train = x_train.reshape(x_train.shape[0], 28*28) / 255.
x_test = x_test.reshape(x_test.shape[0], 28*28) / 255.

或在进行x之前对K.exp()进行上限运算,因为当K.exp(x)大于0时我们不需要知道x的实际值。

def custom_activation(x):
    cond = K.greater(x, 0)
    return K.switch(cond, x, K.exp(K.minimum(x, 0.)) - 1)