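I am trying to figure out how to use TFRecords encoding. I am currently working with a few string-integer pairs: a string is labeled 1 if it consists of the same character repeated, and 0 if its characters differ.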
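To make the labeling rule concrete, here is a minimal sketch of it. The helper name `label_for` is my own and is not used in the script, where the labels are written out by hand. I am also assuming that an empty string should get 0.

```python
def label_for(s):
    """Return 1 if every character in s is the same, else 0."""
    # Hypothetical helper for illustration only.
    # Assumption: an empty string has no repeated character, so it gets 0.
    return 1 if len(set(s)) == 1 else 0


strings = ['aaaa', 'ssss', 'qwer', 'asdf']
labels = [label_for(s) for s in strings]
print(labels)  # [1, 1, 0, 0]
```

This gives the same labels as the hand-written list in the script.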
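After the encode/decode round trip I get an error and my dataset appears to be empty, even though the .tfrecords file exists and does contain some data. Here is my code: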
import tensorflow as tf

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

strings = [
    'aaaa',
    'ssss',
    'qwer',
    'asdf'
]
labels = [1, 1, 0, 0]

filename = 'my_dataset.tfrecords'
writer = tf.python_io.TFRecordWriter(filename)

for i in range(0, len(strings)):
    feature = {
        'x': _bytes_feature(strings[i]),
        'label': _int64_feature(labels[i])
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

# A .tfrecords file is created at this moment
# Now read the file contents
sess = tf.Session()
tf.global_variables_initializer()

def parser(record):
    keys = {
        'x': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    parsed = tf.parse_single_example(record, keys)
    x = parsed['x']
    label = tf.cast(parsed['label'], tf.int32)
    return x, label

dataset = tf.data.TFRecordDataset(filenames=[filename])
dataset = dataset.map(parser)
iterator = dataset.make_one_shot_iterator()
x_next, label_next = iterator.get_next()

print sess.run(x_next)
print sess.run(label_next)
Currently, calling sess.run() raises an "End of sequence" exception. I am trying to get it to print the actual string-label pair:
aaaa 1
Please suggest what might be wrong with my code.