TFRecord功能值错误

时间:2019-03-09 17:47:36

标签: python tensorflow tfrecord

我正在尝试训练一些嵌入,并将我的数据集转换为tfrecord形式。当我像这样向文件写入一个示例时:

tf_features = {
        'given': int64_feature(given),
        'context': bytes_feature(np.array(context).tostring())
}
writer.write(tf.train.Example(features=tf.train.Features(feature=tf_features)).SerializeToString())

其中int64_featurebytes_feature定义为:

def bytes_feature(val):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[val]))

def int64_feature(val):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))

然后我打印出一个示例(给定的上下文)对,我得到类似(698, [686, 439, 464, 775])的信息。

但是,当我尝试像这样读取同一文件时:

def parse_example(w, tf_example):
    feats_dict = {
        'given': tf.FixedLenFeature([], tf.int64, default_value=0),
        'context': tf.FixedLenFeature([], tf.string)
    }
    features = tf.parse_single_example(tf_example, feats_dict)
    context = tf.decode_raw(features['context'], tf.int64)

    context_feats = dict()
    ctx_idx = 0
    for i in range(w):
        if i == w//2: continue
        context_feats['context%d' % ctx_idx] = context[ctx_idx]
        ctx_idx += 1

    return context_feats, features['given']

dataset = tf.data.TFRecordDataset([fname]).map(partial(parse_example, 5))
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
    iter_features, iter_labels = iterator.get_next()
    features = sess.run(iter_features)
    labels = sess.run(iter_labels)
    print(features, labels)

对于与之前相同的上下文对,我得到(464, [686, 439, 464, 775])。给定的标签始终是上下文标签中的第三个标签。

我已经盯着这个代码好几个小时了,但是很沮丧。有人知道怎么回事吗?

1 个答案:

答案 0 :(得分:0)

我想知道发生了什么,这是一个很愚蠢的错误。在以下几行中:

iter_features, iter_labels = iterator.get_next()
features = sess.run(iter_features)
labels = sess.run(iter_labels)

我运行sess.run两次,并且由于迭代器的行为,当我获取功能时,它返回正确的功能,但是当我获取标签时,它返回了标签的标签。下一个示例。

有意义的是,由于用于获取给定上下文对的滑动窗口,我得到的标签始终是上下文中的第三个标签。

我将上述行更改为:

iter_ex = iterator.get_next()
ex = sess.run(iter_ex)
print(ex)

它按预期运行。