"TypeError: 'Tensor' object is not iterable" error with TensorFlow Estimator

Date: 2017-11-11 17:50:27

Tags: python api tensorflow dataset tensorflow-datasets

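I have a procedurally generated (infinite) data source that I am trying to use as input to the high-level TensorFlow Estimator to train an image-based 3D object detector.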

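I set up the Dataset just as in the TensorFlow Estimator Quickstart, and my dataset_input_fn returns a tuple of features and labels Tensors, just as the Estimator.train function specifies and as the tutorial shows. However, when I call the train function I get this error: TypeError: 'Tensor' object is not iterable.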

import tensorflow as tf


def data_generator():
    """Generator for images (features) and ground-truth object positions (labels).

    Samples an image and object positions from a procedurally generated data source.
    """
    while True:
        # `source` is the procedural data source (its definition is not shown)
        source.step()  # generate the next data point
        object_ground_truth = source.get_ground_truth()  # list of 9 floats
        cam_img = source.get_cam_frame()  # image of shape (224, 224, 3)
        yield (cam_img, object_ground_truth)


def dataset_input_fn():
    """TensorFlow `Dataset` object built from the generator."""
    dataset = tf.data.Dataset.from_generator(
        data_generator,
        (tf.uint8, tf.float32),
        (tf.TensorShape([224, 224, 3]), tf.TensorShape([9])))
    dataset = dataset.batch(16)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels


def main():
    """Estimator created from a Keras model, see
    https://www.tensorflow.org/programmers_guide/estimators#creating_estimators_from_keras_models

    Calling `est_vgg16.train()` raises the error.
    """
    ...  # (elided in the original)
    est_vgg16 = tf.keras.estimator.model_to_estimator(keras_model=keras_vgg16)
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
    ...  # (elided in the original)
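As a rough illustration of what one step of dataset_input_fn produces, here is a NumPy-only sketch. The random data_generator is a hypothetical stand-in for the procedural source, and batch() only approximates Dataset.batch(16) by stacking 16 consecutive elements component-wise; it is not TensorFlow code. The point it shows is that features comes out as one batched (16, 224, 224, 3) array, not as a dictionary.

```python
import itertools
import random

import numpy as np

BATCH_SIZE = 16            # same batch size as dataset.batch(16)
IMG_SHAPE = (224, 224, 3)  # image shape declared in from_generator
NUM_LABELS = 9             # 9 ground-truth floats per image


def data_generator():
    """Hypothetical stand-in for the procedural source: infinite (image, labels) pairs."""
    while True:
        cam_img = np.random.randint(0, 256, size=IMG_SHAPE, dtype=np.uint8)
        object_ground_truth = [random.random() for _ in range(NUM_LABELS)]
        yield cam_img, object_ground_truth


def batch(gen, batch_size):
    """Approximate Dataset.batch(): stack consecutive elements component-wise."""
    while True:
        chunk = list(itertools.islice(gen, batch_size))
        if not chunk:
            return
        images, labels = zip(*chunk)
        yield np.stack(images), np.asarray(labels, dtype=np.float32)


features, labels = next(batch(data_generator(), BATCH_SIZE))
print(type(features).__name__, features.shape, features.dtype)
print(labels.shape, labels.dtype)
```

In the real pipeline the same structure comes back from iterator.get_next() as two symbolic Tensors, so dataset_input_fn hands the Estimator a bare Tensor as its features.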

What am I doing wrong?

Here is the stack trace:

Traceback (most recent call last):
  File "./rock_detector.py", line 155, in <module>
    main()
  File "./rock_detector.py", line 117, in main
    est_vgg16.train(input_fn=dataset_input_fn, steps=10)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 145, in model_fn
    labels)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 92, in _clone_and_build_model
    keras_model, features)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/keras/_impl/keras/estimator.py", line 58, in _create_ordered_io
    for key in estimator_io_dict:
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 505, in __iter__
    raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.

(Note: the names of things in the tutorial differ from those in this question.)


1 Answer:

Answer 0 (score: 5)

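Have your input function return the features as a dictionary instead of a bare Tensor. The Keras-to-Estimator wrapper loops over the features it receives (the "for key in estimator_io_dict:" line in the traceback) to match them to the model's inputs, and a Tensor cannot be iterated. For example: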

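def dataset_input_fn():
  ...
  features, labels = iterator.get_next()
  # The dict key is matched against the Keras model's input names
  # (keras_vgg16.input_names), so use the model's actual input layer name.
  return {'image': features}, labels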
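A pure-Python sketch of why the dictionary avoids the error, under a deliberately simplified model of the wrapper. FakeTensor and create_ordered_io are hypothetical stand-ins, not TensorFlow's classes, and "input_1" is an assumed input-layer name. Only the loop over the features and the TypeError message are taken from the traceback; the ValueError for an unknown key is illustrative.

```python
class FakeTensor:
    """Hypothetical stand-in for a graph-mode tf.Tensor: iteration is refused."""

    def __init__(self, name):
        self.name = name

    def __iter__(self):
        # Mirrors the last frame of the traceback (ops.py, __iter__).
        raise TypeError("'Tensor' object is not iterable.")


def create_ordered_io(input_names, estimator_io_dict):
    """Simplified model: order the feature tensors by the model's input names."""
    for key in estimator_io_dict:  # the line that fails in the traceback
        if key not in input_names:
            raise ValueError("unknown input name: %r" % key)
    return [estimator_io_dict[name] for name in input_names]


features = FakeTensor("IteratorGetNext:0")
input_names = ["input_1"]  # hypothetical Keras input-layer name

error = None
try:
    create_ordered_io(input_names, features)  # bare tensor, as in the question
except TypeError as err:
    error = err
    print("TypeError:", err)

ordered = create_ordered_io(input_names, {"input_1": features})  # dict, as in the answer
print([t.name for t in ordered])
```

With a dict, the loop iterates over the string keys and never calls the tensor's __iter__, which is why wrapping the features in a dictionary makes the error go away.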