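Training a Keras model with a TensorFlow dataset

Time: 2020-07-12 09:54:07

Tags: keras tensorflow2.0 tensorflow-datasets

I am new to TensorFlow and I am trying to build an image classifier with ResNet50 to classify the dog breeds dataset, but I cannot work out how to handle the TensorFlow dataset. Here is the code: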

import numpy as np
import pandas as pd 
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.applications import ResNet50

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import metrics
import tensorflow_datasets as tfds




train = tfds.load('stanford_dogs', split='train')
test = tfds.load('stanford_dogs', split='test')

model = keras.Sequential()
model.add(ResNet50(include_top=False, weights='imagenet', pooling='avg', ))
model.add(BatchNormalization())

model.add(Dense(1024, activation = 'relu'))
model.add(BatchNormalization())

model.add(Dense(120, activation='softmax'))

model.layers[0].trainable = False

model.compile(optimizer = 'adam', loss = keras.losses.sparse_categorical_crossentropy, metrics = ['accuracy'])

model.summary()



model.fit(
    train,
    steps_per_epoch = 100,
    epochs = 30,
    verbose =2,
    validation_data = test
       
)

It gives me this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-16-da6b57b304b5> in <module>()
      4     epochs = 30,
      5     verbose =2,
----> 6     validation_data = test
      7 
      8 )

10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise


    KeyError: 'resnet50_input'

The train and test variables have the following types:

type(train)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter

type(test)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter
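
The type alone says little about what each element of the dataset contains. One likely explanation for the KeyError, offered as an assumption rather than a confirmed diagnosis: when tfds.load is called without as_supervised=True, each element is a dictionary of features (for stanford_dogs these include 'image' and 'label'), and Keras then looks for a key named after the model's input layer, here 'resnet50_input'. The sketch below uses a small synthetic dataset in place of the real one (fake_ds and to_pair are hypothetical names, not part of the code above) to show how element_spec reveals the structure and how a dict element can be mapped to an (image, label) pair.

```python
import tensorflow as tf

# Synthetic stand-in for the tfds dataset (assumption: the real elements are
# dicts that contain at least the keys "image" and "label").
fake_ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros([4, 8, 8, 3], dtype=tf.uint8),
    "label": tf.constant([0, 1, 2, 3], dtype=tf.int64),
})

# element_spec describes the structure of one element: here a dict.
print(fake_ds.element_spec)


def to_pair(example):
    """Turn a feature dict into the (inputs, targets) pair that model.fit expects."""
    return example["image"], example["label"]


pairs_ds = fake_ds.map(to_pair)

# After mapping, each element is a 2-tuple of (image, label).
print(pairs_ds.element_spec)
```

Running print(train.element_spec) on the real dataset would confirm or rule out this explanation.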

The dataset is stored as TFRecord files:

!ls /root/tensorflow_datasets/stanford_dogs/0.2.0



dataset_info.json
image.image.json
label.labels.txt
stanford_dogs-test.tfrecord-00000-of-00004
stanford_dogs-test.tfrecord-00001-of-00004
stanford_dogs-test.tfrecord-00002-of-00004
stanford_dogs-test.tfrecord-00003-of-00004
stanford_dogs-train.tfrecord-00000-of-00004
stanford_dogs-train.tfrecord-00001-of-00004
stanford_dogs-train.tfrecord-00002-of-00004
stanford_dogs-train.tfrecord-00003-of-00004

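I searched Google for a solution, but every article I found is about converting a dataset to TFRecord, then reading the TFRecord files and building an input pipeline from them. However, the TensorFlow documentation says that tensorflow_datasets (tfds) provides datasets that are ready to use with TensorFlow.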
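
A tfds dataset can generally be passed to model.fit without any manual TFRecord handling, provided each element is an (image, label) pair and the images are resized to a common shape and batched. The following is a minimal sketch under stated assumptions: IMG_SIZE, BATCH_SIZE, prepare and make_pipeline are illustrative names, and the synthetic raw_ds stands in for tfds.load('stanford_dogs', split='train', as_supervised=True) so that the snippet runs without downloading anything.

```python
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input

IMG_SIZE = 224   # assumed input resolution for ResNet50
BATCH_SIZE = 32  # assumed batch size


def prepare(image, label):
    """Resize to a fixed shape and apply the ResNet50 preprocessing."""
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))  # result is float32
    image = preprocess_input(image)
    return image, label


def make_pipeline(ds, training):
    """Map, optionally shuffle and repeat, then batch and prefetch."""
    ds = ds.map(prepare)
    if training:
        # repeat() keeps the data flowing when steps_per_epoch is set in fit().
        ds = ds.shuffle(1000).repeat()
    return ds.batch(BATCH_SIZE).prefetch(1)


# Synthetic stand-in for the real data. With tfds this would be (assumption):
#   raw_ds = tfds.load('stanford_dogs', split='train', as_supervised=True)
raw_ds = tf.data.Dataset.from_tensor_slices((
    tf.zeros([8, 60, 80, 3], dtype=tf.uint8),
    tf.range(8, dtype=tf.int64),
))

train_ds = make_pipeline(raw_ds, training=True)

# Pull one batch to check the shapes the model would receive.
images, labels = next(iter(train_ds))
print(images.shape, labels.shape)
```

With the real data, the test split would go through make_pipeline with training=False, and the two resulting datasets would replace train and test in the model.fit call above. Resizing before batching matters because the original images do not all share one size.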

0 answers:

No answers