会话关闭后,TFRecordReader仍保持文件锁定

时间:2017-09-16 22:39:47

标签: python python-3.x tensorflow

运行以下脚本(需要一些tfrecords文件,原帖中提供了下载链接;请将data_dir指向这些文件所在的目录):

import os
import shutil
import tempfile
import tensorflow as tf

# Directory whose entries are copied into the temporary directory and then read
# as TFRecord files. The value looks like a placeholder -- point it at a real
# directory of .tfrecords files before running.
data_dir = r'/path/to/tfrecords'

def test_generate_tfrecords_from_csv():
    """Reproduce TFRecord files staying locked after the session is closed.

    Copies every entry of ``data_dir`` into a fresh temporary directory,
    reads one record per copied file through a ``tf.TFRecordReader`` fed by
    a filename queue, decodes the ``label`` feature of each record, and then
    tries to delete the copies once the ``tf.Session`` block has exited.

    Despite its name, the function neither generates records nor touches
    CSV data; it only reads files that already exist.

    The author reports that on Windows with TensorFlow 1.0.1 the deletion
    fails with ``PermissionError`` because the files are still held open.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        filenames = os.listdir(data_dir)
        for f in filenames: # copy tfrecords to tmp dir
            shutil.copy(os.path.join(data_dir, f), os.path.join(tmpdirname, f))
        # From here on `filenames` holds the full paths of the copies, sorted.
        filenames = sorted([os.path.join(tmpdirname, f) for f in filenames])
        # Create a queue that produces the filenames to read.
        # num_epochs=1 with shuffle=False: one pass over the list, in order.
        queue = tf.train.string_input_producer(filenames, num_epochs=1,
                                               shuffle=False)
        with tf.Session() as sess:
            # Per the TensorFlow docs, passing num_epochs makes
            # string_input_producer create a local epoch-counter variable,
            # which has to be initialised before the queue is used.
            sess.run(tf.local_variables_initializer())
            # The returned thread list is discarded and no
            # tf.train.Coordinator is passed, so nothing in this function
            # explicitly stops or joins the queue-runner threads.
            tf.train.start_queue_runners(sess=sess)
            reader = tf.TFRecordReader()
            # One iteration per file; every pass builds its own
            # read -> parse -> decode chain and runs it once.
            for j in range(len(filenames)):
                # Throwaway attribute holder for the key and label tensors.
                class _Result(object) : pass
                result = _Result()
                result.key, value = reader.read(queue)
                # Parse the serialized example, extracting only 'label' as a
                # scalar string feature.
                features_dict = tf.parse_single_example(value, features={
                    'label': tf.FixedLenFeature([], tf.string),})
                # the decode call below is needed, if you replace it with
                # result.label = tf.constant(0) no files are locked
                result.label = tf.decode_raw(features_dict['label'],
                                             tf.float32)
                _ = sess.run([result.label]) # files are locked here
        # The session has been closed by the `with` block at this point.
        listdir = os.listdir(tmpdirname)
        print(tmpdirname, listdir)
        # Explicit deletion attempt; this is where the still-held files make
        # the script fail in the author's report.
        for f in sorted(listdir):
            os.remove(os.path.join(tmpdirname, f))

# Run the reproduction as soon as the script is executed.
test_generate_tfrecords_from_csv()

输出:

C:\Users\MrD\AppData\Local\Temp\tmpf9qejl6m ['img_2013-01-01-00-00.tfrecords', 'img_2013-01-01-00-01.tfrecords', 'img_2013-01-01-00-02.tfrecords']
Traceback (most recent call last):
  File "C:/Users/MrD/.PyCharm2017.2/config/scratches/scratch_49.py", line 50, in <module>
    test_generate_tfrecords_from_csv()
  File "C:/Users/MrD/.PyCharm2017.2/config/scratches/scratch_49.py", line 48, in test_generate_tfrecords_from_csv
    os.remove(os.path.join(tmpdirname, f))
  File "C:\_\Python35\lib\tempfile.py", line 808, in __exit__
    self.cleanup()
  File "C:\_\Python35\lib\tempfile.py", line 812, in cleanup
    _shutil.rmtree(self.name)
  File "C:\_\Python35\lib\shutil.py", line 488, in rmtree
    return _rmtree_unsafe(path, onerror)
  File "C:\_\Python35\lib\shutil.py", line 383, in _rmtree_unsafe
    onerror(os.unlink, fullname, sys.exc_info())
  File "C:\_\Python35\lib\shutil.py", line 381, in _rmtree_unsafe
    os.unlink(fullname)
PermissionError: [WinError 5] Access is denied: 'C:\\Users\\MrD\\AppData\\Local\\Temp\\tmpf9qejl6m\\img_2013-01-01-00-00.tfrecords'

除非我遗漏了什么,否则似乎是TFRecordReader锁定了这些文件。我该如何正确释放这些文件?

更新

我进一步简化了代码——似乎只有在调用tf.decode_raw时文件才会被锁定。这很奇怪,我本以为锁定文件的会是parse_single_example。

>>> tf.__version__
'1.0.1'
>>> sys.version
'3.5.2 (v3.5.2:4def2a2901a5, Jun 25 2016, 22:18:55) [MSC v.1900 64 bit (AMD64)]'

1 个答案:

答案 0 :(得分:1)

我已经在TensorFlow的问题跟踪器上提交了一个issue——请参阅其中使用Dataset API的等效代码,它也存在同样的问题。