我正在尝试读取TFRecord文件,但每当我尝试评估刚刚从文件数据生成的张量时,我的终端都会崩溃。我正在使用RNN(循环神经网络),因此我尝试使用序列数据进行工作。我复制了here中的大部分代码,但我已经添加了自己的内容来尝试包含TFRecordReader。
import tensorflow as tf
import tempfile
import os
sequences = [[1, 2, 3], [4, 5, 1], [1, 2]]
label_sequences = [[0, 1, 0], [1, 0, 0], [1, 1]]
def make_example(sequence, labels):
# The object we return
ex = tf.train.SequenceExample()
# A non-sequential feature of our example
sequence_length = len(sequence)
ex.context.feature["length"].int64_list.value.append(sequence_length)
# Feature lists for the two sequential features of our example
fl_tokens = ex.feature_lists.feature_list["tokens"]
fl_labels = ex.feature_lists.feature_list["labels"]
for token, label in zip(sequence, labels):
fl_tokens.feature.add().int64_list.value.append(token)
fl_labels.feature.add().int64_list.value.append(label)
return ex
def parse_example(filename_queue):
reader = tf.TFRecordReader()
_, example = reader.read(filename_queue)
print(example)
#example = filename_queue
context_features = {
"length": tf.FixedLenFeature([], dtype=tf.int64)
}
sequence_features = {
"tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
"labels": tf.FixedLenSequenceFeature([], dtype=tf.int64)
}
# Parse the example
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
serialized=example,
context_features=context_features,
sequence_features=sequence_features
)
return context_parsed, sequence_parsed
if __name__ == '__main__':
generated_file = ""
#################################
# Write all examples into a TFRecords file
#################################
with tempfile.NamedTemporaryFile(dir=".", delete=False) as fp:
writer = tf.python_io.TFRecordWriter(fp.name)
generated_file = fp.name
for sequence, label_sequence in zip(sequences, label_sequences):
ex = make_example(sequence, label_sequence)
writer.write(ex.SerializeToString())
writer.close()
#################################
# Read contents of TFRecord file
#################################
filename = os.path.join(os.getcwd(), generated_file)
print(filename)
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
filename_queue = tf.train.string_input_producer([filename])
context_parsed, sequence_parsed = parse_example(filename_queue)
print(context_parsed)
print(sequence_parsed)
# THIS IS WHERE I NEED HELP
# terminal freezes if the following lines are uncommented
#print(context_parsed["length"].eval())
#print(sequence_parsed["tokens"].eval())
#print(sequence_parsed["labels"].eval())
我仍然熟悉张量流,所以我很欣赏我做错的解释(不仅仅是代码修复)所以我将来也不会犯类似的错误。谢谢!
-
编辑:好的我已经改变了以上内容:
with tf.Session() as sess:
tf.train.start_queue_runners(sess=sess) # I included this line
sess.run(tf.initialize_all_variables())
但是这给了我以下错误:
错误:tensorflow:QueueRunner中的异常:尝试使用已关闭的会话。
我不确定这有什么问题。根据{{3}},雅罗斯拉夫的评论是正确的,但我不确定我做错了什么。我已经尝试了几种不同格式的代码,但我不确定它是什么,我不知道。
-
EDIT2:好的,我明白了。显然我需要生成一个协调器并将其传递给队列运行器以使其工作。这是我的最终代码:
with tf.Session() as sess:
coord = tf.train.Coordinator()
tf.train.start_queue_runners(coord=coord)
context_parsed, sequence_parsed = parse_example(filename_queue)
for i in range(3):
v1,v2,v3 = sess.run([context_parsed["length"], sequence_parsed["tokens"], sequence_parsed["labels"]])
print(v1,v2,v3)