尝试加载Caltech tensorflow数据集时出现错误。我正在使用tensorflow-datasets GitHub
中的标准代码错误是这样的:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [204,300,3] and element 1 had shape [153,300,3]. [Op:IteratorGetNextSync]
错误指向第for features in ds_train.take(1)
行
代码:
ds_train, ds_test = tfds.load(name="caltech101", split=["train", "test"])
ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
for features in ds_train.take(1):
image, label = features["image"], features["label"]
答案 0 :(得分:1)
问题出在以下事实:数据集包含大小可变的图像(请参见数据集描述here)。 Tensorflow只能将具有相同形状的东西组合在一起,因此您首先需要将图像重塑为通用形状(例如,网络的输入形状)或相应地填充它们。
如果要调整大小,请使用tf.image.resize_images:
def preprocess(features, label):
features['image'] = tf.image.resize_images(features['image'], YOUR_TARGET_SIZE)
# Other possible transformations needed (e.g., converting to float, normalizing to [0,1]
return features, label
如果要填充,请使用tf.image.pad_to_bounding_box(只需将其替换为上面的preprocess
函数并根据需要调整参数)。
通常,对于我所了解的大多数网络,都使用调整大小。
最后,将函数映射到您的数据集:
ds_train = (ds_train
.map(prepocess)
.shuffle(1000)
.batch(128)
.prefetch(10))
注意:错误代码中的变量形状来自shuffle
调用。