从张量流数据集获取特征时出错

时间:2019-08-09 08:59:36

标签: python python-3.x tensorflow runtime-error tensorflow-datasets

尝试加载Caltech tensorflow数据集时出现错误。我正在使用tensorflow-datasets GitHub

中的标准代码

错误是这样的:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [204,300,3] and element 1 had shape [153,300,3]. [Op:IteratorGetNextSync]

错误指向第for features in ds_train.take(1)

代码:

ds_train, ds_test = tfds.load(name="caltech101", split=["train", "test"])

ds_train = ds_train.shuffle(1000).batch(128).prefetch(10)
for features in ds_train.take(1):
    image, label = features["image"], features["label"]

1 个答案:

答案 0 :(得分:1)

问题出在以下事实:数据集包含大小可变的图像(请参见数据集描述here)。 Tensorflow只能将具有相同形状的东西组合在一起,因此您首先需要将图像重塑为通用形状(例如,网络的输入形状)或相应地填充它们。

如果要调整大小,请使用tf.image.resize_images

def preprocess(features, label):
  features['image'] = tf.image.resize_images(features['image'], YOUR_TARGET_SIZE)
  # Other possible transformations needed (e.g., converting to float, normalizing to [0,1]
  return features, label

如果要填充,请使用tf.image.pad_to_bounding_box(只需将其替换为上面的preprocess函数并根据需要调整参数)。 通常,对于我所了解的大多数网络,都使用调整大小。

最后,将函数映射到您的数据集:

ds_train = (ds_train
            .map(prepocess)
            .shuffle(1000)
            .batch(128)
            .prefetch(10))

注意:错误代码中的变量形状来自shuffle调用。