TensorFlow Datasets: "Cannot batch tensors with different shapes" error even after resizing?

Date: 2020-11-02 12:20:05

Tags: python tensorflow computer-vision

I'm running into some trouble with the tensorflow-datasets module. Using the stanford_dogs dataset, I resize the images to [180,180], but when training the model, the error message suggests TensorFlow is trying to load the images at their original size.

What am I doing wrong?

The code that reproduces the error (and the error itself) is below. The dataset is around 750 MB. It can be copy-pasted into Google Colab and run to reproduce the issue.

import io
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

def _normalize_img(img, label):
    img = tf.cast(img, tf.float32) / 255.
    img = tf.image.resize(img,[180,180])
    return (img, label)
    

train_dataset, test_dataset = tfds.load(name="stanford_dogs", split=['train', 'test'], as_supervised=True)

train_dataset = train_dataset.shuffle(1024).batch(32)
train_dataset = train_dataset.map(_normalize_img)

test_dataset = test_dataset.batch(32)
test_dataset = test_dataset.map(_normalize_img)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64,2,padding='same',activation='relu',input_shape=(180,180,3)),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Conv2D(32,2,padding='same',activation='relu'),
    tf.keras.layers.MaxPooling2D(2),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120,activation='softmax')
])


model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss='sparse_categorical_crossentropy')


history = model.fit(
    train_dataset,
    epochs=5)

This fails with the error:

InvalidArgumentError:  Cannot batch tensors with different shapes in component 0. First element had shape [278,300,3] and element 1 had shape [375,500,3].
     [[node IteratorGetNext (defined at <ipython-input-29-15023f95f627>:39) ]] [Op:__inference_train_function_4908]

1 answer:

Answer 0 (score: 0):

You are getting this error because the tf.data.Dataset API cannot create a batch from tensors with different shapes. Since the batch transformation returns a tensor of shape (batch, height, width, channels), the values of height, width, and channels must be constant across the whole dataset. You can read more about why in the Introduction to Tensors guide.
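For illustration, here is a minimal, self-contained sketch (using hypothetical all-zero tensors in place of stanford_dogs images) showing the same failure mode: batching variable-shaped elements raises the error, while resizing them first makes them batchable:

import tensorflow as tf

# Two "images" with different heights and widths (hypothetical stand-ins for the dataset)
images = [tf.zeros([278, 300, 3]), tf.zeros([375, 500, 3])]
ds = tf.data.Dataset.from_generator(
    lambda: images,
    output_signature=tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32))

# ds.batch(2) would fail when iterated with:
#   InvalidArgumentError: Cannot batch tensors with different shapes in component 0.

# Resizing first gives every element the same static shape, so batching works:
ds = ds.map(lambda img: tf.image.resize(img, [180, 180])).batch(2)
for batch in ds:
    print(batch.shape)  # (2, 180, 180, 3)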

Batching after resizing will fix your problem:

train_dataset = train_dataset.shuffle(1024)
train_dataset = train_dataset.map(_normalize_img)
# batch only once every image has been resized to the same shape
train_dataset = train_dataset.batch(32)
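The same reordering applies to test_dataset. Put together with the code from the question, the corrected input pipeline might look like this (a sketch; only the order of map and batch changes):

train_dataset, test_dataset = tfds.load(
    name="stanford_dogs", split=['train', 'test'], as_supervised=True)

# Normalize and resize first, then batch, so every element has shape [180, 180, 3]
train_dataset = train_dataset.shuffle(1024).map(_normalize_img).batch(32)
test_dataset = test_dataset.map(_normalize_img).batch(32)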