Merging multiple TensorFlow datasets?

Asked: 2019-04-11 04:06:05

Tags: tensorflow

I have an image dataset. I applied several augmentation methods, created a separate dataset for each method, then concatenated them all into a single dataset and used it to train a CNN, but I get this error:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args)
    415             attrs=("executor_type", executor_type,
    416                    "config_proto", config),
--> 417             ctx=ctx)
    418       # Replace empty list with None
    419       outputs = outputs or None

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     if any(ops._is_keras_symbolic_tensor(x) for x in inputs):

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: logits and labels must have the same first dimension, got logits shape [100352,3] and labels shape [32]
     [[{{node loss/dense_1_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits}}]] [Op:__inference_keras_scratch_graph_2206]

Code:

path_train_ds = tf.data.Dataset.from_tensor_slices(X_train)
label_train_ds = tf.data.Dataset.from_tensor_slices(tf.cast(y_train, tf.int64))
image_train_ds = path_train_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

image_label_train_ds = tf.data.Dataset.zip((image_train_ds, label_train_ds))

def random_flip_up_down(img, label):
    tf_img = tf.image.random_flip_up_down(img, 1)
    return tf_img, label

image_label_train_random_flip_up_down = image_label_train_ds.map(random_flip_up_down)

def random_saturation(img, label):
    tf_img = tf.image.random_saturation(img, 0.3, 0.8, 1)
    return tf_img, label

image_label_train_random_saturation = image_label_train_ds.map(random_saturation)

def concet(ds):
    ds0 = ds[0]
    for ds1 in ds[1:]:
        ds0 = ds0.concatenate(ds1)
    return ds0

ds = [image_label_train_random_flip_up_down, image_label_train_random_saturation]
image_label_ds_aug = concet(ds)
model_4.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history_4 = model_4.fit(image_label_ds_aug, epochs=8, steps_per_epoch=math.ceil(10000/BATCH_SIZE),
                        validation_data=image_label_test_ds, callbacks=[MetricsCheckpoint('logs')])
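For reference, here is a minimal, self-contained sketch of the concatenation pattern the code above relies on. The images and labels below are dummy placeholders, not the real data; the point is only that `Dataset.concatenate` joins datasets with the same `(image, label)` element structure:

```python
import tensorflow as tf

# Dummy stand-ins for the real images/labels (hypothetical data).
images = tf.random.uniform((4, 8, 8, 3))
labels = tf.constant([0, 1, 2, 0], dtype=tf.int64)

base = tf.data.Dataset.from_tensor_slices((images, labels))

# One dataset per augmentation, as in the question.
flipped = base.map(lambda img, lbl: (tf.image.random_flip_up_down(img), lbl))
saturated = base.map(lambda img, lbl: (tf.image.random_saturation(img, 0.3, 0.8), lbl))

# Concatenation requires matching element specs and sums the lengths.
merged = flipped.concatenate(saturated)

print(tf.data.experimental.cardinality(merged).numpy())  # prints 8
```

Note that `concatenate` yields an unbatched dataset of single `(image, label)` pairs; before passing it to `Model.fit`, one would typically apply `merged.shuffle(...).batch(BATCH_SIZE)` so the model receives batches.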

0 Answers:

No answers