检查输入时出错:期望flatten_input具有4个维度,但数组的形状为(404,13)

时间:2020-08-04 14:17:17

标签: python machine-learning neural-network

import tensorflow as tf

from tensorflow import keras

import numpy as np

import matplotlib.pyplot as plt

data = keras.datasets.boston_housing

(x_train , y_train) , (x_test , y_test) = data.load_data()

model = keras.Sequential([

    keras.layers.Flatten(input_shape=(28 , 28 )),

    keras.layers.Dense(128 , activation="relu"),

    keras.layers.Dense(10 , activation="softmax")
])

model.compile(optimizer="adam" , loss="sparse_categorical_crossentropy" , metrics=["accuracy"])

model.fit(x_train , y_train ,epochs=5 )

test_loss , test_acc = model.evaluate(x_test , y_test)

print("tested acc: ", test_acc)

2 个答案:

答案 0 :(得分:0)

Flatten层将3D +张量转换为2D。我猜想这个模型是为二进制MNIST设计的,它具有不同的输入形状,并且您尝试在另一个数据集上使用它。波士顿房屋数据集已经有2D输入,因此在这里没有意义。您可以通过更改输入形状来执行此操作,但这没有任何意义:

keras.layers.Flatten(input_shape=(13,)),

所有您需要做的就是删除它,它将运行正常。然后,由于要处理回归问题,因此您必须更改损失函数,指标和最终激活函数。最终的,更正的代码:

import tensorflow as tf

from tensorflow import keras

import numpy as np

import matplotlib.pyplot as plt

data = keras.datasets.boston_housing

(x_train , y_train) , (x_test , y_test) = data.load_data()

model = keras.Sequential([

    keras.layers.Dense(128 , activation="relu"),

    keras.layers.Dense(1)
])

model.compile(optimizer="adam" , loss="mae")

model.fit(x_train , y_train ,epochs=5 )

test_loss = model.evaluate(x_test , y_test)

print("test loss: ", test_loss)
 32/102 [========>.....................] - ETA: 0s - loss: 5.9563
102/102 [==============================] - 0s 682us/sample - loss: 6.7218

test loss:  6.721758244084377

答案 1 :(得分:0)

Flatten用于使2d数据像图像一样平坦,因此基本上是将2d列表转换为1d列表,因此应将Flatten更改为Input

第二个错误是声明输入形状。

input_shape=(28 , 28 )

您声明了28x28,但我想您想拥有28个具有28个特征的样本。这是不变的。要正确执行此操作,请将输入形状定义为灵活的,它将匹配训练和预测中任意数量的样本。我需要做的所有事情就是通过一个样本具有多少功能

input_shape=(28, )

这就是它的样子

model = keras.Sequential([

    keras.layers.Input(input_shape=(28, )),

    keras.layers.Dense(128 , activation="relu"),

    keras.layers.Dense(10 , activation="softmax")
])