我有一个图像数据集。我对它应用了几种增强方法，并为每种方法分别创建了一个数据集，然后把所有数据集连接成一个，用来训练 CNN，但训练时出现了下面的错误：
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args)
415 attrs=("executor_type", executor_type,
416 "config_proto", config),
--> 417 ctx=ctx)
418 # Replace empty list with None
419 outputs = outputs or None
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [100352,3] and labels shape [32]
[[{{node loss/dense_1_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:__inference_keras_scratch_graph_2206]
代码:
# Build the training input pipeline: one dataset sliced from X_train and one from the labels.
# NOTE(review): X_train presumably holds image file paths (going by the name path_train_ds) — confirm.
path_train_ds = tf.data.Dataset.from_tensor_slices(X_train)
# Labels are cast to int64 before being sliced into their own dataset.
label_train_ds = tf.data.Dataset.from_tensor_slices(tf.cast(y_train, tf.int64))
# load_and_preprocess_image is defined outside this excerpt; it is applied to every element.
image_train_ds = path_train_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
# Pair each preprocessed image with its label so elements are (image, label) tuples.
image_label_train_ds = tf.data.Dataset.zip((image_train_ds, label_train_ds))
def random_flip_up_down(img, label):
    """Apply a random vertical flip to one (image, label) element.

    Args:
        img: Image tensor handed to ``tf.image.random_flip_up_down``.
        label: Label belonging to ``img``; returned untouched.

    Returns:
        A ``(flipped_image, label)`` tuple.
    """
    # The op-level seed is fixed at 1, exactly as in the original positional call.
    flipped = tf.image.random_flip_up_down(img, seed=1)
    return flipped, label
# Augmented copy of the training set: every (image, label) element is passed through random_flip_up_down.
image_label_train_random_flip_up_down=image_label_train_ds.map(random_flip_up_down)
def random_saturation(img, label):
    """Apply a random saturation adjustment to one (image, label) element.

    Args:
        img: Image tensor handed to ``tf.image.random_saturation``.
        label: Label belonging to ``img``; returned untouched.

    Returns:
        A ``(adjusted_image, label)`` tuple.
    """
    # Saturation factor bounds are 0.3 and 0.8; the op-level seed is fixed at 1,
    # exactly as in the original positional call.
    adjusted = tf.image.random_saturation(img, lower=0.3, upper=0.8, seed=1)
    return adjusted, label
# Augmented copy of the training set: every (image, label) element is passed through random_saturation.
image_label_train_random_saturation =image_label_train_ds.map(random_saturation)
def concet(ds):
    """Chain several datasets end to end into a single dataset.

    Args:
        ds: Iterable of objects that expose ``concatenate(other)`` (for
            example ``tf.data.Dataset``). Lists and tuples work as before;
            any other iterable, such as a generator, is accepted as well.

    Returns:
        The first dataset with every following dataset concatenated onto
        it in order. When ``ds`` holds a single dataset, that dataset is
        returned unchanged.

    Raises:
        IndexError: If ``ds`` is empty.
    """
    remaining = iter(ds)
    try:
        combined = next(remaining)
    except StopIteration:
        # Keep the exception type the indexing-based version raised, but
        # say what actually went wrong.
        raise IndexError("concet() needs at least one dataset") from None
    for extra in remaining:
        combined = combined.concatenate(extra)
    return combined
# Collect the augmented datasets and chain them into one training dataset.
ds =[image_label_train_random_flip_up_down,image_label_train_random_saturation]
image_label_ds_aug = concet(ds)
# NOTE(review): the reported error shows logits of shape [100352, 3] against labels of shape [32].
# 100352 = 32 * 56 * 56, so model_4 (defined outside this excerpt) presumably applies its final
# Dense layer to an unflattened feature map — check for a missing Flatten/pooling layer. TODO confirm.
model_4.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
# NOTE(review): no .batch() or .repeat() on image_label_ds_aug is visible in this excerpt even though
# steps_per_epoch is derived from BATCH_SIZE — confirm the dataset is batched/repeated elsewhere.
history_4 = model_4.fit(image_label_ds_aug, epochs=8, steps_per_epoch=math.ceil(10000/BATCH_SIZE),
validation_data=(image_label_test_ds),callbacks = [MetricsCheckpoint('logs')])