Concatenate layer in Keras causes fitting to fail

Date: 2018-11-12 18:53:29

Tags: python tensorflow keras

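Whenever I concatenate the outputs of two layers (for example, because I want to apply softmax to some of the outputs and a different activation function to the rest), the network always fails to learn.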

Here is some sample code that demonstrates the problem:

from tensorflow.keras.layers import Lambda, Input, Dense, Concatenate, Dropout, Reshape, \
                                    Conv2D, Flatten, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.losses import mse, categorical_crossentropy, binary_crossentropy
from tensorflow.keras.utils import plot_model, to_categorical
from tensorflow.keras import backend as K
from tensorflow.keras import optimizers

import numpy as np
import matplotlib.pyplot as plt
import argparse
import os

# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

no_cls = int(np.max(y_train)) + 1  # number of MNIST classes (10), as a plain Python int
width = 20

extra_dims = True

image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
y_train = to_categorical(y_train, num_classes=width if extra_dims else no_cls)
y_test = to_categorical(y_test, num_classes=width if extra_dims else no_cls)

hidden_dim = 512
batch_sz = 256
eps = 10

ins = Input(shape=(original_dim,))
x = Dense(hidden_dim)(ins)
cls_pred = Dense(no_cls, activation="softmax")(x)
other    = Dense(width-no_cls)(x)
outs = Concatenate()([cls_pred, other])

encoder = Model(ins, outs if extra_dims else cls_pred, name="encoder")
encoder.summary()

def cust_loss_fn(y_true, y_pred):
    return categorical_crossentropy(y_true[:no_cls], y_pred[:no_cls])

optimiser = optimizers.SGD(learning_rate=0.003, clipvalue=0.1)  # `lr` is the deprecated name
encoder.compile(optimizer=optimiser, loss=cust_loss_fn,
                metrics=["accuracy"])

encoder.fit(x_train, y_train,
            batch_size=batch_sz,
            epochs=eps,
            validation_data=(x_test, y_test))

score = encoder.evaluate(x_test, y_test)
print(score)

print(encoder.predict(x_train[0:10]))
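A likely culprit is the slicing in cust_loss_fn. Assuming Keras passes y_true and y_pred to the loss as 2-D tensors of shape (batch_size, width), which is the usual layout for a Dense or Concatenate output, [:no_cls] indexes the first axis, i.e. the batch, not the class columns. The loss would then see only the first no_cls samples of each batch, with all 20 columns including the unbounded linear ones. The NumPy sketch below only compares the shapes of the two slices; the zero array is a stand-in for a batch, not real model output.

```python
import numpy as np

# Stand-in for one batch of labels/predictions with the question's sizes.
batch_size, width, no_cls = 256, 20, 10
y = np.zeros((batch_size, width), dtype=np.float32)

rows = y[:no_cls]      # slices axis 0: first 10 samples, all 20 columns
cols = y[:, :no_cls]   # slices axis 1: every sample, first 10 (softmax) columns

print(rows.shape)  # (10, 20)
print(cols.shape)  # (256, 10)
```

If this is the cause, slicing the class axis instead, y_true[:, :no_cls] and y_pred[:, :no_cls], should make the loss cover only the softmax columns of every sample in the batch.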

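With extra_dims = False, i.e. without the Concatenate layer, the network consistently reaches 88% accuracy within 10 epochs. With extra_dims = True, accuracy stays at around 8% and the loss does not decrease during training.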
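The reported accuracy may be affected separately from the loss. With extra_dims = True the labels are one-hot over all width = 20 columns, and an argmax-based categorical accuracy (assumed here to be what the "accuracy" string resolves to for this output) takes the argmax over all 20 outputs. The "other" units are linear and unbounded, so one of them can exceed every softmax probability (each at most 1), and the argmax then falls outside the class columns. The NumPy sketch below uses a hand-made, hypothetical prediction to compare that accuracy with one restricted to the first no_cls columns; the function names are illustrative only.

```python
import numpy as np

no_cls, width = 10, 20

# One sample whose true class is 3, one-hot over all `width` columns
# (like the padded labels from to_categorical(..., num_classes=width)).
y_true = np.zeros((1, width), dtype=np.float32)
y_true[0, 3] = 1.0

# Hypothetical prediction: the softmax part puts 0.9 on class 3, but one
# of the unconstrained linear "other" units happens to output 2.5.
y_pred = np.zeros((1, width), dtype=np.float32)
y_pred[0, :no_cls] = 0.1 / (no_cls - 1)
y_pred[0, 3] = 0.9
y_pred[0, 15] = 2.5

def accuracy_all_columns(y_true, y_pred):
    # argmax over every column, as a plain categorical accuracy would do
    return np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1))

def accuracy_class_columns(y_true, y_pred, no_cls):
    # argmax restricted to the softmax (class) columns only
    return np.mean(np.argmax(y_true[:, :no_cls], axis=-1)
                   == np.argmax(y_pred[:, :no_cls], axis=-1))

print(accuracy_all_columns(y_true, y_pred))            # 0.0
print(accuracy_class_columns(y_true, y_pred, no_cls))  # 1.0
```

The same column slice could be used in a custom metric function passed to compile() in place of "accuracy", so that the metric only scores the softmax part of the output.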

Am I doing something wrong?

0 answers:

No answers yet.