Infinite loop when reading a TFRecord file

Asked: 2017-01-18 21:34:41

Tags: python tensorflow eval binaryfiles

I'm trying to read a TFRecord file, but my terminal freezes whenever I try to evaluate a tensor that was just produced from the file's data. I'm working with an RNN (recurrent neural network), so I'm trying to use sequence data. I copied most of the code from here, but I've added my own bits to try to incorporate a TFRecordReader.

import tensorflow as tf
import tempfile
import os

sequences = [[1, 2, 3], [4, 5, 1], [1, 2]]
label_sequences = [[0, 1, 0], [1, 0, 0], [1, 1]]

def make_example(sequence, labels):
    # The object we return
    ex = tf.train.SequenceExample()
    # A non-sequential feature of our example
    sequence_length = len(sequence)
    ex.context.feature["length"].int64_list.value.append(sequence_length)
    # Feature lists for the two sequential features of our example
    fl_tokens = ex.feature_lists.feature_list["tokens"]
    fl_labels = ex.feature_lists.feature_list["labels"]
    for token, label in zip(sequence, labels):
        fl_tokens.feature.add().int64_list.value.append(token)
        fl_labels.feature.add().int64_list.value.append(label)
    return ex

def parse_example(filename_queue):

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)
    print(example)

    #example = filename_queue

    context_features = {
        "length": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "labels": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }

    # Parse the example
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    return context_parsed, sequence_parsed

if __name__ == '__main__':
    generated_file = ""
    #################################
    # Write all examples into a TFRecords file
    #################################
    with tempfile.NamedTemporaryFile(dir=".", delete=False) as fp:
        writer = tf.python_io.TFRecordWriter(fp.name)
        generated_file = fp.name
        for sequence, label_sequence in zip(sequences, label_sequences):
            ex = make_example(sequence, label_sequence)
            writer.write(ex.SerializeToString())
        writer.close()

    #################################
    # Read contents of TFRecord file
    #################################


    filename = os.path.join(os.getcwd(), generated_file)
    print(filename)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        filename_queue = tf.train.string_input_producer([filename])
        context_parsed, sequence_parsed = parse_example(filename_queue)
        print(context_parsed)
        print(sequence_parsed)

        # THIS IS WHERE I NEED HELP
        # terminal freezes if the following lines are uncommented

        #print(context_parsed["length"].eval()) 
        #print(sequence_parsed["tokens"].eval()) 
        #print(sequence_parsed["labels"].eval()) 

I'm still getting familiar with TensorFlow, so I'd appreciate an explanation of what I'm doing wrong (not just a code fix), so that I don't make similar mistakes in the future. Thanks!

-

EDIT: OK, I've changed the code above to this:

with tf.Session() as sess:
    tf.train.start_queue_runners(sess=sess) # I included this line
    sess.run(tf.initialize_all_variables()) 

But this gives me the following error:

ERROR:tensorflow:Exception in QueueRunner: Attempted to use a closed Session.

I'm not sure what's wrong here. According to {{3}}, Yaroslav's comment is right, but I'm not sure what I'm doing wrong. I've tried arranging the code in a few different ways, but I can't figure out what the problem is.

-

EDIT 2: OK, I figured it out. Apparently I needed to create a Coordinator and pass it to the queue runners to make this work. Here's my final code:

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(coord=coord)
    context_parsed, sequence_parsed = parse_example(filename_queue)
    for i in range(3):
        v1, v2, v3 = sess.run([context_parsed["length"],
                               sequence_parsed["tokens"],
                               sequence_parsed["labels"]])
        print(v1, v2, v3)
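
For reference, here is the same read path with the runner threads shut down explicitly before the session closes. The request_stop()/join() calls are my own addition on top of the final code above (a sketch based on my understanding of the TF 1.x queue API), and as far as I can tell they are what keeps the QueueRunner from ever touching a closed session:

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    filename_queue = tf.train.string_input_producer([filename])
    context_parsed, sequence_parsed = parse_example(filename_queue)

    # Start the threads that feed the filename queue, keeping a handle on
    # them through the coordinator so they can be stopped cleanly later.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(3):
        v1, v2, v3 = sess.run([context_parsed["length"],
                               sequence_parsed["tokens"],
                               sequence_parsed["labels"]])
        print(v1, v2, v3)

    # Stop the runner threads and wait for them before the `with` block
    # closes the session, so they never enqueue into a closed session.
    coord.request_stop()
    coord.join(threads)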

0 Answers:

No answers yet.