Training a model with a TensorFlow Dataset

Date: 2020-06-15 15:48:07

Tags: python dataframe tensorflow

I am trying to understand this implementation of an ML model (https://www.kaggle.com/reighns/groupkfold-and-stratified-groupkfold-efficientnet) by writing my own adapted version of it.

Using the same dataset as in that example, I generate a tf.data.Dataset object as follows:

train_dataset_fold_1 = (
    tf.data.Dataset
    .from_tensor_slices((train_paths_fold_1, train_labels_fold_1))
    .map(decode_image, num_parallel_calls=AUTO)
    .map(data_augment, num_parallel_calls=AUTO)
    .repeat()
    .shuffle(512)
    .batch(BATCH_SIZE)
    .prefetch(AUTO))
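For context, `decode_image` and `data_augment` are not shown above. A minimal sketch of what they might look like (this is assumed, not the notebook's exact code; the JPEG format, `IMG_SIZE`, and the specific augmentations are assumptions):

```python
import tensorflow as tf

IMG_SIZE = 256  # assumed to match the model's input_shape below

def decode_image(path, label):
    # Read a JPEG from disk, scale pixel values to [0, 1],
    # and resize to the fixed size the model expects.
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    return img, label

def data_augment(img, label):
    # Graph-mode augmentations; tf.image ops keep the whole
    # pipeline inside the TensorFlow graph.
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_brightness(img, max_delta=0.1)
    return img, label
```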

The model I want to train is:

def get_model():
    with strategy.scope():
        return Sequential([
            Conv2D(filters=64, kernel_size=(5, 5), padding='Same',
                   activation='relu', input_shape=(256, 256, 3)),
            BatchNormalization(),
            Conv2D(filters=64, kernel_size=(5, 5), padding='Same',
                   activation='relu'),
            MaxPool2D(pool_size=(2, 2)),
            Dropout(0.25),
            Conv2D(filters=64, kernel_size=(3, 3), padding='Same',
                   activation='relu'),
            BatchNormalization(),
            Conv2D(filters=64, kernel_size=(3, 3), padding='Same',
                   activation='relu'),
            MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
            Dropout(0.25),
            Flatten(),
            Dense(256, activation="relu"),
            Dropout(0.5),
            Dense(2, activation="softmax")])

However, when I try to train it, I get the following error:

AttributeError: 'PrefetchDataset' object has no attribute 'ndim'

I am running training as:

learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy', 
                                            patience=3, 
                                            verbose=1, 
                                            factor=0.5, 
                                            min_lr=0.00001)
model_fold_1 = get_model()
history_1 = model_fold_1.fit(train_dataset_fold_1,
                    epochs=EPOCHS,
                    verbose = 1,
                    callbacks=[learning_rate_reduction],
                    steps_per_epoch=STEPS_PER_EPOCH,
                    validation_data=valid_dataset_fold_1)

This seems to be related to the format of the training data input.

Currently, the input is a TensorFlow tensor produced by the decode_image and data_augment functions. I suspect the error comes from some NumPy-array method being called on the input. However, when I try to change the code in the decode and augmentation functions, the tensor-based acceleration is lost.
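For what it's worth, this particular `ndim` AttributeError is commonly reported when the model or callbacks come from the standalone `keras` package, whose older `fit` treats the input as a NumPy array, whereas `tf.keras.Model.fit` accepts tf.data datasets directly. One way to narrow it down is a sanity check with a throwaway tf.keras model and dummy data shaped like the real pipeline (a sketch; the shapes mirror the model above, everything else is assumed):

```python
import tensorflow as tf

# Tiny in-memory pipeline with the same structure as train_dataset_fold_1.
xs = tf.zeros([8, 256, 256, 3])
ys = tf.one_hot(tf.zeros([8], dtype=tf.int32), depth=2)
ds = (tf.data.Dataset.from_tensor_slices((xs, ys))
      .repeat()
      .shuffle(8)
      .batch(4)
      .prefetch(tf.data.AUTOTUNE))

# If fit succeeds here but fails with the real model, the problem is in
# how the real model was built (e.g. layers imported from standalone
# keras), not in the dataset itself.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(256, 256, 3)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy")
history = model.fit(ds, epochs=1, steps_per_epoch=2, verbose=0)
```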

Is there a quick way to fix this?

0 Answers:

There are no answers yet.