无法导入解析tfrecords序列化示例

时间:2019-12-22 21:09:30

标签: python tensorflow protocol-buffers tensorflow2.0 tfrecord

我正在尝试将一些索引和浮点值保存到 tfrecords ,并使用 tf.data API对其进行解析。

我正在使用 tensorflow2.1 版本 2.1.0rc1 在Ubuntu 18.04和python 3.7.5上

首先,我创建示例:

def create_example(row_ix, col_ix, cooc_val):
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                "rows": tf.train.Feature(int64_list=tf.train.Int64List(value=[row_ix])),
                "cols": tf.train.Feature(int64_list=tf.train.Int64List(value=[col_ix])),
                "cooc": tf.train.Feature(float_list=tf.train.FloatList(value=[cooc_val]))
            }
        )
    )

保存到 tfrecord

的方法
with tf.io.TFRecordWriter("train.tfrecord") as writer:
    for row_ix, col_ix, cooc_val in tqdm(zip(glove_rows, glove_cols, glove_data)):
        example = create_example(row_ix, col_ix, cooc_val)
        writer.write(example.SerializeToString())

现在,我使用 tf.data API打开 tfrecord 文件:

train_ds = tf.data.TFRecordDataset(os.path.join(data_path, "train.tfrecord"))

然后,我解析示例

feature_description = {
    "rows": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64, default_value=0),
    "cols": tf.io.FixedLenFeature(shape=[1], dtype=tf.int64, default_value=0),
    "cooc": tf.io.FixedLenFeature(shape=[1], dtype=tf.float32, default_value=0.0)
}

def _parse_function(example_proto):
  return tf.io.parse_single_example(example_proto, feature_description)

train_ds = train_ds.map(_parse_function)

返回以下错误:

InvalidArgumentError: Could not parse example input, value: '
cols
C
rows


cooc
IA'
     [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]]

0 个答案:

没有答案