数据API:ValueError:使用数据集作为输入时,不支持y参数

时间:2020-08-17 18:17:16

标签: python tensorflow tensorflow-datasets

我有45000张大小为224 * 224的图像,以numpy数组存储。这个名为source_arr的数组的形状为45000,224,224,它可以装入内存。

我使用tf.data API将这个数组划分为训练,测试和验证数组,并对它们进行预处理(将灰度标准化并将其转换为3通道RGB)。

我写了一个预处理功能,例如:

def pre_process(x):
     x_norm = (x - mean_Rot_MIP) / Var_Rot_MIP
     # Stacking along the last dimension to avoid having to move channel axis
     x_norm_3ch = tf.stack((x_norm, x_norm, x_norm), axis=-1)
     return x_norm_3ch

X_train_cases_idx.idx包含source_arr中作为训练数据一部分的图像的索引。

我已从source_arr中读取了数据集对象中的相应训练图像,例如:

X_train = tf.data.Dataset.from_tensor_slices([source_arr[i] for i in X_train_cases_idx.idx])

然后我将pre_process函数应用于训练图像,例如 X_train = X_train.map(pre_process)

这是一个多类分类问题,因此我将标签变量y转换为1种热编码,例如:

lb = LabelBinarizer()
y_train = lb.fit_transform(y_train)

X_train和y_train的长度为36000

我在RESNET50上执行model.fit操作,如:

H = model.fit(X_train, y_train, batch_size = BS, validation_data=(X_val, y_val), epochs = NUM_EPOCHS, shuffle =False)

我得到一个错误:

ValueError: `y` argument is not supported when using dataset as input.

我知道我需要将X_train和y_train都作为一个元组传递给Dataset对象。 我该怎么做?

1 个答案:

答案 0 :(得分:1)

您将source_arr和y_train作为numpy数组;因此您可以:

data_set = tf.data.Dataset.from_tensor_slices(  (source_arr , y_train) )

如果您将source_arr和y_train作为tf.dataset:

data_set = tf.data.Dataset.zip( (source_arr , y_train) )