CNN Uint8数据类型问题

时间:2020-05-23 15:16:47

标签: numpy keras cnn

我正在尝试使用Keras为MNIST创建CNN,但是我的代码存在一些问题。 我主要收到此错误:

    TypeError: Value passed to parameter 'input' has DataType uint8 not in list of allowed values: float16, bfloat16, float32, float64

这是我的代码:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Conv2D, Dropout, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.utils import to_categorical
(Train_Data, Train_Labels), (Test_Data, Test_Labels) = tf.keras.datasets.mnist.load_data()


Train_Data = Train_Data.reshape(60000,28,28,1)
Test_Data = Test_Data.reshape(10000,28,28,1)

def save(model):
    model.save("CNN")
def load(name):
    model = tf.keras.models.load_model(name)

model = keras.Sequential()
model.add(Conv2D(784, kernel_size=3, activation='relu'))
model.add(MaxPooling2D(pool_size=(5,5)))

model.add(Dropout(.2))
model.add(keras.layers.Flatten())
model.add(Dense(25, activation='relu'))
model.add(Dense(10, activation='softmax'))

model.compile(optimzer='adam', loss="mse", metrics=['accuracy'])



model.fit(Train_Data, Train_Labels)

我不知道该怎么办,我们将不胜感激

1 个答案:

答案 0 :(得分:1)

MNIST数据的原始图像的类型为uint8(值在[0,255]范围内),但是在训练CNN之前,您需要对其进行归一化。通常,您需要将其规范化为零附近的统一边界,例如[-0.5,0.5]。您可以通过添加以下行来做到这一点:

Train_Data = Train_Data / 255 - 0.5
Test_Data = Train_Data / 255 - 0.5
相关问题