Tensorflow数据集和网络的形状不匹配

时间:2020-04-29 12:16:32

标签: python tensorflow keras

在使用Tensorflow 2定义非常简单的网络时,我遇到了与形状有关的错误。

我的代码是:

import tensorflow as tf
import pandas as pd

data = pd.read_csv('data.csv')

target = data.pop('result')
target = tf.keras.utils.to_categorical(target.values, num_classes=3)
data_set = tf.data.Dataset.from_tensor_slices((data.values, target))

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=data.shape[1:]),
    tf.keras.layers.Dense(12, activation='relu'),
    tf.keras.layers.Dense(3, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy')
model.fit(data_set, epochs=5)

对fit()的调用引发以下错误:

ValueError: Input 0 of layer sequential is incompatible with the layer: expected axis -1 of input shape to have value 12 but received input with shape [12, 1]

遍历代码:

  1. 输入的CSV文件有十三列-最后一列是标签
  2. 这将转换为3位一键编码
  3. 数据集由两个张量构成-一个为(12,)形状,另一个为(3,)形状
  4. 网络输入层将其期望形状定义为值数据形状,而忽略了第一个轴(批次大小)

对于为什么网络的数据形状和预期的数据形状不匹配,我感到很困惑-尤其是因为后者是通过引用前者来定义的。

1 个答案:

答案 0 :(得分:1)

在数据集的末尾添加.batch()

data_set = tf.data.Dataset.from_tensor_slices((data.values, target)).batch(8)