我正在尝试将一些索引和浮点值保存到 tfrecords ,并使用 tf.data API对其进行解析。
我正在使用 tensorflow2.1 版本 2.1.0rc1 在Ubuntu 18.04和python 3.7.5上
首先,我创建示例:
def create_example(row_ix, col_ix, cooc_val):
return tf.train.Example(
features=tf.train.Features(
feature={
"rows": tf.train.Feature(int64_list=tf.train.Int64List(value=[row_ix])),
"cols": tf.train.Feature(int64_list=tf.train.Int64List(value=[col_ix])),
"cooc": tf.train.Feature(float_list=tf.train.FloatList(value=[cooc_val]))
}
)
)
保存到 tfrecord
的方法with tf.io.TFRecordWriter("train.tfrecord") as writer:
for row_ix, col_ix, cooc_val in tqdm(zip(glove_rows, glove_cols, glove_data)):
example = create_example(row_ix, col_ix, cooc_val)
writer.write(example.SerializeToString())
现在,我使用 tf.data API打开 tfrecord 文件:
train_ds = tf.data.TFRecordDataset(os.path.join(data_path, "train.tfrecord"))
然后,我解析示例
feature_description = {
"rows": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64, default_value=0),
"cols": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64, default_value=0),
"cooc": tf.io.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=0.0)
}
def _parse_function(example_proto):
return tf.io.parse_single_example(example_proto, feature_description)
train_ds = train_ds.map(_parse_function)
返回以下错误:
InvalidArgumentError: Could not parse example input, value: '
cols
C
rows
cooc
IA'
[[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]]