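I am trying to train a model that detects objects in drawings, using TensorFlow. I wrote an input function based on the input_fn that Google provides for the QuickDraw dataset, but I get the error shown below when I run it. Here is the code for the function: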
def input_func():
    """
    The input function for estimator
    Returns:
        Dictionary of features
        Target labels
    """
    dataset = tf.data.Dataset.list_files(tfrecord_path)
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = dataset.shuffle(buffer_size=10)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        tf.data.TFRecordDataset,
        cycle_length=10,
        block_length=1
    )
    dataset = dataset.map(
        functools.partial(parse_tfexample, mode=mode),
        num_parallel_calls=10
    )
    dataset = dataset.prefetch(1000000)
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset.shuffle(buffer_size=1000000)
    dataset = dataset.padded_batch(
        batch_size, padded_shapes=dataset.output_shapes
    )
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels
I get the following error:
Traceback (most recent call last):
File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 298, in <module>
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 209, in main
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 471, in train_and_evaluate
return executor.run()
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 610, in run
return self.run_local()
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\training.py", line 711, in run_local
saving_listeners=saving_listeners)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 354, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1207, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1234, in _train_model_default
input_fn, model_fn_lib.ModeKeys.TRAIN))
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1075, in _get_features_and_labels_from_input_fn
self._call_input_fn(input_fn, mode))
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\estimator\estimator.py", line 1162, in _call_input_fn
return input_fn(**kwargs)
File "A:\Code\Machine Learning\Software Engineering project\Quick Draw\Train_Model.py", line 64, in input_func
batch_size, padded_shapes = dataset.output_shapes
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 945, in padded_batch
drop_remainder)
File "C:\Users\shind\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\data\ops\dataset_ops.py", line 2505, in __init__
"Batching of padded sparse tensors is not currently supported")
TypeError: Batching of padded sparse tensors is not currently supported
What causes this error, and how can I fix it?
Answer 0 (score: 0)
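The problem was in the parse_tfexample function. It builds a dictionary of parsed features, and the entry with the key "drawing" is a sparse tensor, because it is declared with tf.VarLenFeature. padded_batch cannot pad sparse tensors, which is what the TypeError reports. So I simply converted that entry to a dense tensor with tf.sparse.to_dense(). Here is the code for parse_tfexample: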
def parse_tfexample(example, mode):
    """Parse a single tf example"""
    features = {
        "drawing": tf.VarLenFeature(dtype=tf.float32),
        "shape": tf.FixedLenFeature([2], dtype=tf.int64)
    }
    if mode != tf.estimator.ModeKeys.PREDICT:
        features["class_index"] = tf.FixedLenFeature([1], dtype=tf.int64)
    parsed_features = tf.parse_single_example(example, features)
    parsed_features["drawing"] = tf.sparse.to_dense(parsed_features["drawing"])
    print(parsed_features)
    labels = None
    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = parsed_features["class_index"]
    return parsed_features, labels
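To illustrate why the dense conversion matters, here is a small pure-Python sketch. It is not TensorFlow code and not how TensorFlow implements these operations. The helpers sparse_to_dense and padded_batch are hypothetical stand-ins written only for this example, and the sample values are made up. The sketch shows the idea: a sparse value is stored as indices plus values, and padding a batch to a common length only makes sense once each example is an ordinary dense sequence.

```python
from typing import List, Sequence


def sparse_to_dense(indices: Sequence[int], values: Sequence[float],
                    length: int, default: float = 0.0) -> List[float]:
    """Expand a 1-D sparse (indices, values, length) triple into a dense list."""
    dense = [default] * length
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


def padded_batch(rows: Sequence[Sequence[float]],
                 pad_value: float = 0.0) -> List[List[float]]:
    """Pad every row on the right to the length of the longest row."""
    width = max((len(r) for r in rows), default=0)
    return [list(r) + [pad_value] * (width - len(r)) for r in rows]


# Two made-up "drawing" features of different lengths, given in sparse form.
example_a = sparse_to_dense(indices=[0, 1, 2],
                            values=[0.1, 0.2, 1.0], length=3)
example_b = sparse_to_dense(indices=[0, 1, 2, 3, 4, 5],
                            values=[0.3, 0.4, 0.0, 0.5, 0.6, 1.0], length=6)

# Once both examples are dense, they can be padded to one common length.
batch = padded_batch([example_a, example_b])
print(batch)
```

In the answer's code, tf.sparse.to_dense plays the role of the first helper inside parse_tfexample, so every element reaching padded_batch is already dense.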