TypeError: Batching of padded sparse tensors is not currently supported when padding a TF Dataset object

Date: 2019-02-25 13:59:54

Tags: python tensorflow deep-learning

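I am trying to train a model that recognizes objects in drawings, using TensorFlow. I wrote an input function based on the input_fn that Google provides for the QuickDraw dataset, but running it raises the error in the title. The function is: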

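import functools

import tensorflow as tf  # TF 1.x API (tf.estimator, make_one_shot_iterator)


# tfrecord_path, mode and batch_size come from the enclosing scope.
def input_func():
    """The input function for the estimator.

    Returns:
        Dictionary of features
        Target labels
    """
    dataset = tf.data.Dataset.list_files(tfrecord_path)
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.repeat()
    # Read records from 10 files concurrently, one record per file at a time.
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1
    )
    dataset = dataset.map(
        functools.partial(parse_tfexample, mode=mode),
        num_parallel_calls=10
    )
    dataset = dataset.prefetch(1000000)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # shuffle() returns a new dataset, so the result must be reassigned.
        dataset = dataset.shuffle(buffer_size=1000000)
    # Pad variable-length features to the longest element in each batch.
    dataset = dataset.padded_batch(
        batch_size, padded_shapes=dataset.output_shapes
    )
    features, labels = dataset.make_one_shot_iterator().get_next()

    return features, labels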
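For context: when a component's padded shape has an unknown dimension (as a variable-length vector does in dataset.output_shapes), padded_batch pads that component with zeros up to the longest element in the batch. The pure-Python sketch below only illustrates that behavior on 1-D lists; pad_batch is a hypothetical helper, not part of TensorFlow.

```python
from typing import List, Sequence


def pad_batch(vectors: Sequence[Sequence[float]], pad_value: float = 0.0) -> List[List[float]]:
    """Hypothetical helper mimicking how padded_batch pads 1-D components.

    Every vector is extended with pad_value up to the length of the
    longest vector in the batch.
    """
    max_len = max(len(v) for v in vectors)
    return [list(v) + [pad_value] * (max_len - len(v)) for v in vectors]


if __name__ == "__main__":
    batch = pad_batch([[0.1, 0.2, 0.3], [0.4], [0.5, 0.6]])
    print(batch)  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0], [0.5, 0.6, 0.0]]
```

TensorFlow performs the equivalent step per component inside the input pipeline, but, as the error message says, not for sparse tensors.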

I get the following error:

Traceback (most recent call last):
  File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 298, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
    _sys.exit(main(argv))
  File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 209, in main
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 471, in train_and_evaluate
    return executor.run()
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 610, in run
    return self.run_local()
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 711, in run_local
    saving_listeners=saving_listeners)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1207, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1234, in _train_model_default
    input_fn, model_fn_lib.ModeKeys.TRAIN))
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1075, in _get_features_and_labels_from_input_fn
    self._call_input_fn(input_fn, mode))
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1162, in _call_input_fn
    return input_fn(**kwargs)
  File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 64, in input_func
    batch_size, padded_shapes = dataset.output_shapes
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 945, in padded_batch
    drop_remainder)
  File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 2505, in __init__
    "Batching of padded sparse tensors is not currently supported")
TypeError: Batching of padded sparse tensors is not currently supported

What causes this error, and how can I fix it?

1 Answer

Answer 0 (score: 0)

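The problem was in the parse_tfexample function. It returns a dictionary in which the "drawing" entry is a sparse tensor: a tf.VarLenFeature is parsed into a SparseTensor, and padded_batch cannot pad sparse tensors. Converting it to a dense tensor with tf.sparse.to_dense() gives a variable-length 1-D dense tensor that padded_batch can pad, which fixed the error. Here is the code for parse_tfexample: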

import tensorflow as tf  # TF 1.x API


def parse_tfexample(example, mode):
    """Parse a single tf example"""
    features = {
        # VarLenFeature is parsed into a SparseTensor.
        "drawing": tf.VarLenFeature(dtype=tf.float32),
        "shape": tf.FixedLenFeature([2], dtype=tf.int64)
    }
    if mode != tf.estimator.ModeKeys.PREDICT:
        # Labels are not available at prediction time.
        features["class_index"] = tf.FixedLenFeature([1], dtype=tf.int64)
    parsed_features = tf.parse_single_example(example, features)
    # Convert the sparse "drawing" tensor to a dense 1-D tensor so that
    # padded_batch can pad it; this is what fixes the TypeError.
    parsed_features["drawing"] = tf.sparse.to_dense(parsed_features["drawing"])
    labels = None
    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = parsed_features["class_index"]

    return parsed_features, labels
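The fix can be checked end to end on a tiny in-memory dataset. This is a minimal sketch that assumes TensorFlow 2.x with eager execution, so it uses the tf.io equivalents of tf.VarLenFeature, tf.FixedLenFeature and tf.parse_single_example and iterates the dataset directly instead of using a one-shot iterator. make_example, parse and the sample values are hypothetical, and the "shape" feature only holds placeholder values.

```python
import tensorflow as tf


def make_example(drawing, class_index):
    """Hypothetical helper: serialize one tf.train.Example with the same keys."""
    feature = {
        "drawing": tf.train.Feature(float_list=tf.train.FloatList(value=drawing)),
        # Placeholder values; only the [2] length matters for this sketch.
        "shape": tf.train.Feature(int64_list=tf.train.Int64List(value=[len(drawing), 1])),
        "class_index": tf.train.Feature(int64_list=tf.train.Int64List(value=[class_index])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


FEATURES = {
    "drawing": tf.io.VarLenFeature(tf.float32),  # parsed as a SparseTensor
    "shape": tf.io.FixedLenFeature([2], tf.int64),
    "class_index": tf.io.FixedLenFeature([1], tf.int64),
}


def parse(serialized):
    parsed = tf.io.parse_single_example(serialized, FEATURES)
    # Densify the variable-length feature so padded_batch can pad it.
    parsed["drawing"] = tf.sparse.to_dense(parsed["drawing"])
    return parsed, parsed["class_index"]


records = [make_example([0.1, 0.2, 0.3], 1), make_example([0.4], 2)]
dataset = tf.data.Dataset.from_tensor_slices(records).map(parse)
dataset = dataset.padded_batch(
    2,
    padded_shapes=({"drawing": [None], "shape": [2], "class_index": [1]}, [1]),
)
features, labels = next(iter(dataset))
print(features["drawing"].numpy())
```

The second drawing has only one value, so its row is padded with zeros to length 3, the length of the longest drawing in the batch.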