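I am new to TensorFlow. I am trying to build an image classifier with ResNet50 to classify the Stanford Dogs breed dataset, but I cannot work out how to handle the TensorFlow dataset objects. Here is the code: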
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.applications import ResNet50
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
import tensorflow_datasets as tfds
train= tfds.load('stanford_dogs', split= 'train')
test= tfds.load('stanford_dogs', split= 'test')
model = keras.Sequential()
model.add(ResNet50(include_top=False, weights='imagenet', pooling='avg', ))
model.add(BatchNormalization())
model.add(Dense(1024, activation = 'relu'))
model.add(BatchNormalization())
model.add(Dense(120, activation='softmax'))
model.layers[0].trainable = False
model.compile(optimizer = 'adam', loss = keras.losses.sparse_categorical_crossentropy, metrics = ['accuracy'])
model.summary()
model.fit(
    train,
    steps_per_epoch = 100,
    epochs = 30,
    verbose =2,
    validation_data = test
)
It gives me this error:
KeyError Traceback (most recent call last)
<ipython-input-16-da6b57b304b5> in <module>()
4 epochs = 30,
5 verbose =2,
----> 6 validation_data = test
7
8 )
10 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
966 except Exception as e: # pylint:disable=broad-except
967 if hasattr(e, "ag_error_metadata"):
--> 968 raise e.ag_error_metadata.to_exception(e)
969 else:
970 raise
KeyError: 'resnet50_input'
The train and test variables have the following type:
type(train)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter
type(test)
tensorflow.python.data.ops.dataset_ops.DatasetV1Adapter
The dataset is stored as TFRecord files:
!ls /root/tensorflow_datasets/stanford_dogs/0.2.0
dataset_info.json
image.image.json
label.labels.txt
stanford_dogs-test.tfrecord-00000-of-00004
stanford_dogs-test.tfrecord-00001-of-00004
stanford_dogs-test.tfrecord-00002-of-00004
stanford_dogs-test.tfrecord-00003-of-00004
stanford_dogs-train.tfrecord-00000-of-00004
stanford_dogs-train.tfrecord-00001-of-00004
stanford_dogs-train.tfrecord-00002-of-00004
stanford_dogs-train.tfrecord-00003-of-00004
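I searched Google for a solution, but every article I found was about converting a dataset to TFRecord, then reading the TFRecord files and building an input pipeline from them. However, the TensorFlow documentation says that tensorflow_datasets (tfds) defines datasets that are ready to use with TensorFlow.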
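One likely reading of the KeyError: without as_supervised=True, tfds.load yields each example as a dict of features (with keys such as 'image' and 'label'), and Keras treats a dict as named model inputs, so it looks for a key called 'resnet50_input'. The sketch below shows one way the examples could be mapped to batched (image, label) tuples before calling model.fit. It is a minimal sketch, not a definitive fix: the names to_image_label and make_pipeline, the 224x224 size, the batch size, and the shuffle buffer are all assumptions, and a small synthetic dataset stands in for the real tfds data so the snippet runs without a download.

```python
import tensorflow as tf

IMG_SIZE = 224    # assumed input size; the images in the dataset vary in size
BATCH_SIZE = 32   # assumed batch size


def to_image_label(example):
    """Turn a feature dict into the (image, label) tuple that Keras expects."""
    image = tf.image.resize(example["image"], (IMG_SIZE, IMG_SIZE))
    # Apply the preprocessing that matches ResNet50's ImageNet weights.
    image = tf.keras.applications.resnet50.preprocess_input(image)
    return image, example["label"]


def make_pipeline(ds, training=False):
    """Hypothetical helper: map, optionally shuffle and repeat, then batch."""
    ds = ds.map(to_image_label)
    if training:
        # repeat() so that steps_per_epoch * epochs never exhausts the data
        ds = ds.shuffle(1000).repeat()
    return ds.batch(BATCH_SIZE)


# Synthetic stand-in for tfds.load('stanford_dogs', split='train'):
# four blank images stored as feature dicts with 'image' and 'label' keys.
fake_ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros((4, 100, 120, 3), dtype=tf.uint8),
    "label": tf.range(4, dtype=tf.int64),
})

images, labels = next(iter(make_pipeline(fake_ds)))
print(images.shape, labels.shape)
```

With the real data, the same helper would presumably be used as model.fit(make_pipeline(train, training=True), steps_per_epoch=100, epochs=30, verbose=2, validation_data=make_pipeline(test)). Batching matters here as well, because the raw dataset yields single unbatched images of different sizes.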