Using "sparse_categorical_crossentropy"

Time: 2020-11-03 01:34:46

Tags: python keras deep-learning

I'm having trouble understanding why sparse categorical cross-entropy doesn't work with the SVHN dataset.

import tensorflow as tf
from scipy.io import loadmat
import numpy as np

train = loadmat('data/train_32x32.mat')
test = loadmat('data/test_32x32.mat')

x_train = train['X']
y_train = train['y']
x_train = x_train.astype('float64')
y_train = y_train.astype('int64')

x_test = test['X']
x_test = x_test.astype('float64')
y_test = test['y']
y_test = y_test.astype('int64')

# move the sample axis from last to first: (32, 32, 3, n) -> (n, 32, 32, 3)
x_train = np.moveaxis(x_train, -1, 0)
x_test = np.moveaxis(x_test, -1, 0)
def colored_to_gray(x):
    '''
    input shape: n_sample, n_x, n_y, n_channel
    output shape: n_sample, n_x, n_y, 1
    this is a rudimentary way of converting a colored image into a gray image
    '''
    x = np.mean(x, axis=-1, keepdims=True)
    return x

def normalize_data(x):
    '''
    normalize data so that values are between 0 and 1
    '''
    x = x / 255.0
    return x

x_train = colored_to_gray(x_train)
x_test = colored_to_gray(x_test)

x_train = normalize_data(x_train)
x_test = normalize_data(x_test)
print("Shape of Training Data: {}".format(x_train.shape))
print("Shape of Training Labels: {}".format(y_train.shape))
print("Shape of Testing Data: {}".format(x_test.shape))
print("Shape of Testing Labels: {}".format(y_test.shape))

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential([
    Flatten(name='Flatten_Input', input_shape=x_train.shape[1:]),
    Dense(units=1024, activation='relu', name='Dense_1'),
    Dense(units=512, activation='relu', name='Dense_2'),
    Dense(units=256, activation='relu', name='Dense_3'),
    Dense(units=32, activation='relu', name='Dense_4'),
    Dense(units=10, activation='softmax', name='Output')
])

opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, epochs=2, batch_size=256)
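The reordering, grayscale, and normalization steps above can be checked without the SVHN files. A minimal sketch, assuming only NumPy and a hypothetical random array `raw` laid out like the `.mat` images as (32, 32, 3, n_samples); the two helper functions mirror the ones in the question:

```python
import numpy as np

def colored_to_gray(x):
    """Average the channel axis: (n, h, w, c) -> (n, h, w, 1)."""
    return np.mean(x, axis=-1, keepdims=True)

def normalize_data(x):
    """Scale 8-bit pixel values into [0, 1]."""
    return x / 255.0

# Hypothetical stand-in for train['X']: 5 random 32x32 RGB images,
# stored with the sample axis last, as in the SVHN .mat files.
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3, 5)).astype("float64")

x = np.moveaxis(raw, -1, 0)          # (5, 32, 32, 3): samples first, as Keras expects
x = normalize_data(colored_to_gray(x))

print(x.shape)                       # (5, 32, 32, 1)
print(int(np.prod(x.shape[1:])))     # 1024 features after Flatten
```

The flattened size, 32 × 32 × 1 = 1024, is what `Flatten_Input` hands to `Dense_1`, so the image shapes themselves are consistent with the model.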

With the model.fit call above, I expected the model to train on the 10 classes. Instead, I get "nan" as the loss and an accuracy of 0.

  • Can someone explain what is happening here?
  • Does it have to do with the size or the type of the input?

Thanks
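On the question of input type: `sparse_categorical_crossentropy` expects each target to be a single integer class index in `[0, num_classes)`, and the SVHN cropped-digit `.mat` files label the digit 0 as 10. A label of 10 has no matching unit among the 10 softmax outputs, which is one plausible source of a nan loss. The sketch below is a NumPy-only check under that assumption; `check_sparse_labels`, `remap_svhn_labels`, and the five sample labels are hypothetical.

```python
import numpy as np

NUM_CLASSES = 10  # matches Dense(units=10) in the output layer

def check_sparse_labels(y, num_classes=NUM_CLASSES):
    """Return (min, max, ok), where ok means every label lies in [0, num_classes)."""
    y = np.asarray(y)
    lo, hi = int(y.min()), int(y.max())
    return lo, hi, bool(lo >= 0 and hi < num_classes)

def remap_svhn_labels(y):
    """Hypothetical helper: SVHN stores digit 0 as 10; map it to 0 so labels are 0..9."""
    y = np.asarray(y, dtype="int64").copy()  # leave the caller's array untouched
    y[y == 10] = 0
    return y

# Made-up labels shaped like train['y'], i.e. (n_samples, 1), using SVHN's 1..10 convention.
y_train = np.array([[1], [10], [3], [9], [10]], dtype="int64")

print(check_sparse_labels(y_train))   # (1, 10, False): label 10 is out of range
y_fixed = remap_svhn_labels(y_train)
print(check_sparse_labels(y_fixed))   # (0, 9, True)
```

The remapped labels keep the `(n_samples, 1)` shape of `train['y']`; only their values change.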

1 answer:

Answer 0 (score: 0)

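Sparse categorical cross-entropy is for the case where each image belongs to only one class. In the SVHN dataset that is not the case: for example, image 3213 has 2 (among other labels), so the images are multi-class. Change it to 1 and it should work. Also, you are not using "accuracy" as a metric.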
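The answer's first point, that sparse categorical cross-entropy assumes one class per image, follows from what the loss computes: for each sample it takes minus the log of the predicted probability at the single integer label, then averages over samples. A minimal NumPy sketch (function names hypothetical; this is not the Keras implementation) that also checks the result against the one-hot (categorical) form:

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels):
    """Mean of -log(probs[i, labels[i]]), with one integer class index per sample."""
    labels = np.asarray(labels).reshape(-1)      # accepts shape (n,) or (n, 1)
    n, num_classes = probs.shape
    # This sketch rejects out-of-range labels explicitly; Keras may report them differently.
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes})")
    picked = probs[np.arange(n), labels]         # probability assigned to the true class
    return float(-np.mean(np.log(picked)))

def categorical_crossentropy(probs, one_hot):
    """Same loss for one-hot targets: -mean(sum(one_hot * log(probs)))."""
    return float(-np.mean(np.sum(one_hot * np.log(probs), axis=1)))

# Made-up softmax outputs for 2 samples over 3 classes.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = np.array([[0], [2]])                    # exactly one class per sample

sparse = sparse_categorical_crossentropy(probs, labels)
dense = categorical_crossentropy(probs, np.eye(3)[labels.reshape(-1)])
print(round(sparse, 4), np.isclose(sparse, dense))

try:
    sparse_categorical_crossentropy(probs, [[0], [3]])   # 3 has no column with 3 classes
except ValueError as err:
    print("rejected:", err)
```

Both forms give the same value when each sample has exactly one class; the sparse form simply stores that class as an index instead of a one-hot row.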