我的代码使用keras模型和tf数据集,并从磁盘加载图像文件。 运行此代码时:
import tensorflow as tf
import os
# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename, label):
image_string = tf.read_file(filename)
#image_string=tf.gfile.FastGFile(filename).read()
image_decoded = tf.image.decode_png(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# 图片文件的列表
pics=os.listdir("D:/kaggle/flower/data/train/daisy")
print(pics)
filenames = tf.constant(["D:/kaggle/flower/data/train/daisy/"+e for e in pics])
# label[i]就是图片filenames[i]的label
labels = tf.constant([0]*len(pics))
# 此时dataset中的一个元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
# 此时dataset中的一个元素是(image_resized, label)
def from_dataset(ds):
return lambda: ds.make_one_shot_iterator().get_next()
# 此时dataset中的一个元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffer_size=1000).batch(32).repeat()
print(dataset.output_shapes)
# iterator = dataset.make_one_shot_iterator()
# one_element = iterator.get_next()
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
loss='categorical_crossentropy',
metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)
est_inception_v3.train(input_fn=from_dataset(dataset),steps=10)
eval_result = est_inception_v3.evaluate(input_fn=from_dataset(dataset))
print(eval_result)
我收到以下错误:
Traceback (most recent call last):
File "C:/Users/lxm1042642197/PycharmProjects/models/samples/dt.py", line 54, in <module>
est_inception_v3.train(input_fn=from_dataset(dataset),steps=10)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 711, in _train_model
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 694, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 145, in model_fn
labels)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 92, in _clone_and_build_model
keras_model, features)
File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 58, in _create_ordered_io
for key in estimator_io_dict:
File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 505, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.