Tensorflow估算器错误:'Tensor'对象不可迭代

时间:2018-02-07 06:32:38

标签: python tensorflow-datasets tensorflow-estimator

我的代码使用keras模型和tf数据集,并从磁盘加载图像文件。 运行此代码时:

import tensorflow as tf
import os

# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  #image_string=tf.gfile.FastGFile(filename).read()
  image_decoded = tf.image.decode_png(image_string)

  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

# 图片文件的列表
pics=os.listdir("D:/kaggle/flower/data/train/daisy")
print(pics)
filenames = tf.constant(["D:/kaggle/flower/data/train/daisy/"+e for e in pics])
# label[i]就是图片filenames[i]的label
labels = tf.constant([0]*len(pics))

# 此时dataset中的一个元素是(filename, label)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)


# 此时dataset中的一个元素是(image_resized, label)

def from_dataset(ds):
    return lambda: ds.make_one_shot_iterator().get_next()

# 此时dataset中的一个元素是(image_resized_batch, label_batch)
dataset = dataset.shuffle(buffer_size=1000).batch(32).repeat()
print(dataset.output_shapes)
# iterator = dataset.make_one_shot_iterator()
# one_element = iterator.get_next()

keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                          loss='categorical_crossentropy',
                          metric='accuracy')
# Create an Estimator from the compiled Keras model. Note the initial model
# state of the keras model is preserved in the created Estimator.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)

est_inception_v3.train(input_fn=from_dataset(dataset),steps=10)
eval_result = est_inception_v3.evaluate(input_fn=from_dataset(dataset))
print(eval_result)

我收到以下错误:

Traceback (most recent call last):
  File "C:/Users/lxm1042642197/PycharmProjects/models/samples/dt.py", line 54, in <module>
    est_inception_v3.train(input_fn=from_dataset(dataset),steps=10)

  File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)

  File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "C:\Anaconda3\lib\site-packages\tensorflow\python\estimator\estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)

  File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 145, in model_fn
    labels)

  File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 92, in _clone_and_build_model
    keras_model, features)
  File "C:\Anaconda3\lib\site-packages\tensorflow\python\keras\_impl\keras\estimator.py", line 58, in _create_ordered_io
    for key in estimator_io_dict:
  File "C:\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

0 个答案:

没有答案