TypeError when trying to build an Example protobuf

Asked: 2021-03-20 19:00:42

Tags: python tensorflow machine-learning

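I am trying to load the MNIST dataset into three separate TensorFlow datasets (training, validation, and test) and then serialize the contents of each dataset into Example protobufs. I get an error and cannot figure out what I am doing wrong.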

Here is my code:

import tensorflow as tf
import tensorflow_datasets as tfds

# the test split is used as-is
test_ds = tfds.load('mnist', split='test')

# first 90% of the train split is used for training
train_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', to=90, unit='%'))
# last 10% of the train dataset is used for validation
val_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', from_=90, to=100, unit='%'))

train_ds = train_ds.shuffle(buffer_size=5000) # 5000 is a magic number here
val_ds = val_ds.shuffle(buffer_size=5000)
test_ds = test_ds.shuffle(buffer_size=5000)

BytesList = tf.train.BytesList
FloatList = tf.train.FloatList
Int64List = tf.train.Int64List
Feature = tf.train.Feature
Features = tf.train.Features
Example = tf.train.Example

for data in train_ds:
    image = tf.io.serialize_tensor(data["image"])
    label = data["label"]
    data_example = Example(
        features=Features(
            feature={
                "image": Feature(bytes_list=BytesList(value=[image])),
                "label": Feature(int64_list=Int64List(value=[label]))
            }))

This is the error I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-7e28c1154183> in <module>
     12         features=Features(
     13             feature={
---> 14                 "image": Feature(bytes_list=BytesList(value=[image])),
     15                 "label": Feature(int64_list=Int64List(value=[label]))
     16             }))

TypeError: <tf.Tensor: shape=(), dtype=string, numpy=b'\x08\x04\x12\x0c\x12\x02\x08\x1c\x12\x02\x08\x1c\x12\x02 has type tensorflow.python.framework.ops.EagerTensor, but expected one of: bytes
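The message says that `BytesList` expects Python `bytes`, while `tf.io.serialize_tensor` returns a scalar string `EagerTensor`. One possible way around this, sketched below, is to call `.numpy()` on the serialized tensor, which yields a `bytes` object in eager mode. The `image` and `label` values here are synthetic stand-ins (an assumption) with the same shape and dtype as the MNIST elements in the traceback, so the snippet does not need to download the dataset. The label is converted to a plain `int` in the same way, as a precaution.

```python
import tensorflow as tf

# Synthetic stand-ins for data["image"] and data["label"] (assumption:
# a (28, 28, 1) uint8 image and a scalar int64 label, as in MNIST).
image = tf.zeros((28, 28, 1), dtype=tf.uint8)
label = tf.constant(7, dtype=tf.int64)

# serialize_tensor returns a scalar tf.string tensor, not Python bytes.
serialized = tf.io.serialize_tensor(image)

# In eager mode, .numpy() on a scalar string tensor gives a bytes object.
image_bytes = serialized.numpy()
# Convert the label tensor to a plain Python int as well.
label_int = int(label.numpy())

data_example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "image": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image_bytes])),
            "label": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[label_int])),
        }))

# The Example can now be turned into a byte string, e.g. for a TFRecord file.
record = data_example.SerializeToString()
```

In the original loop this would amount to using `tf.io.serialize_tensor(data["image"]).numpy()` and `int(data["label"].numpy())` in place of the raw tensors.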

I also tried skipping tf.io.serialize_tensor(data["image"]) and passing data["image"] directly when building the Feature, which led to a different error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-11-5d1317c1b7df> in <module>
     12         features=Features(
     13             feature={
---> 14                 "image": Feature(bytes_list=BytesList(value=[data["image"]])),
     15                 "label": Feature(int64_list=Int64List(value=[label]))
     16             }))

TypeError: <tf.Tensor: shape=(28, 28, 1), dtype=uint8, numpy=
array([[[  0],
        [  0],
        [  0],
     has type tensorflow.python.framework.ops.EagerTensor, but expected one of: bytes
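This second error has the same cause: a `(28, 28, 1)` `uint8` tensor is not a `bytes` object either. If the tensor should be stored without `tf.io.serialize_tensor`, one alternative is to dump its raw buffer with NumPy's `tobytes()` and restore it later with `tf.io.decode_raw`. The sketch below uses a synthetic image (an assumption) and hard-codes the shape, because raw bytes carry no shape or dtype information.

```python
import tensorflow as tf

# Synthetic stand-in for data["image"] (assumption: (28, 28, 1) uint8).
image = tf.cast(tf.reshape(tf.range(28 * 28) % 256, (28, 28, 1)), tf.uint8)

# Raw pixel buffer as Python bytes: 28 * 28 * 1 values of one byte each.
raw_bytes = image.numpy().tobytes()

image_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[raw_bytes]))

# Reading it back: decode_raw yields a flat uint8 tensor, so the shape
# has to be supplied again by the reader.
decoded = tf.io.decode_raw(raw_bytes, tf.uint8)
restored = tf.reshape(decoded, (28, 28, 1))
```

Compared with `tf.io.serialize_tensor`, this keeps the stored payload smaller and simpler, at the cost of the reader having to know the dtype and shape in advance.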

What am I doing wrong, and how can I fix it?

0 answers:

No answers yet