我正在尝试为TextSum模型创建自己的训练数据。根据我的理解,我需要将我的文章和摘要放到二进制文件中(在TFRecords中)。但是,我无法从原始文本文件创建自己的训练数据。我不太清楚格式,所以我尝试使用以下代码创建一个非常简单的二进制文件:
files = os.listdir(path)
writer = tf.python_io.TFRecordWriter("test_data")
for i, file in enumerate(files):
content = open(os.path.join(path, file), "r").read()
example = tf.train.Example(
features = tf.train.Features(
feature = {
'content': tf.train.Feature(bytes_list=tf.train.BytesList(value=[content]))
}
)
)
serialized = example.SerializeToString()
writer.write(serialized)
我尝试使用以下代码读出此test_data文件的值
reader = open("test_data", 'rb')
len_bytes = reader.read(8)
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example_pb2.Example.FromString(example_str)
但我总是收到以下错误:
File "dailymail_corpus_to_tfrecords.py", line 34, in check_file
example_pb2.Example.FromString(example_str)
File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 770, in FromString
message.MergeFromString(s)
File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
if self._InternalParse(serialized, 0, length) != length:
File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1117, in InternalParse
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 791, in _SkipLengthDelimited
raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.
我不知道出了什么问题。如果您有任何建议可以解决这个问题,请告诉我。
答案 0 :(得分:3)
对于那些有同样问题的人。我必须查看TensorFlow的源代码,看看他们如何使用TFRecordWriter写出数据。我已经意识到他们实际上写了8个字节的长度,4个字节用于CRC校验,这意味着前12个字节用于标头。因为在TextSum代码中,示例二进制文件似乎只有8字节的标头,这就是为什么他们使用reader.read(8)获取数据的长度并将其余部分作为特征读取。
我的工作解决方案是:
reader = open("test_data", 'rb')
len_bytes = reader.read(8)
reader.read(4) #ignore next 4 bytes
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example_pb2.Example.FromString(example_str)
答案 1 :(得分:2)
我希望你的textsum目录中有data_convert_example.py。如果没有,您可以在这篇文章中找到它:https://github.com/tensorflow/models/pull/379/files
使用python文件将给定的二进制玩具数据(文件名:数据目录中的数据转换为文本格式)。
python data_convert_example.py --command binary_to_text --in_file ../data/data --out_file ../data/result_text
您可以使用result_text格式查看实际的文本格式。
以该格式准备数据并使用相同的python脚本从text_to_binary转换并将结果用于training / testing / eval。