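I am trying to load the MNIST dataset into 3 separate TensorFlow datasets (training, validation and test) and then serialize the contents of each dataset into Example protobufs. I get an error and cannot figure out what I am doing wrong.
Here is my code: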
import tensorflow as tf
import tensorflow_datasets as tfds
train_ds, test_ds = tfds.load('fashion_mnist', split=['train', 'test'])
# first 90% of the train split is used for training
train_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', to=90, unit='%'))
# last 10% of the train dataset is used for validation
val_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', from_=90, to=100, unit='%'))
train_ds = train_ds.shuffle(buffer_size=5000) # 5000 is a magic number here
val_ds = val_ds.shuffle(buffer_size=5000)
test_ds = test_ds.shuffle(buffer_size=5000)
BytesList = tf.train.BytesList
FloatList = tf.train.FloatList
Int64List = tf.train.Int64List
Feature = tf.train.Feature
Features = tf.train.Features
Example = tf.train.Example
for data in train_ds:
    image = tf.io.serialize_tensor(data["image"])
    label = data["label"]
    data_example = Example(
        features=Features(
            feature={
                "image": Feature(bytes_list=BytesList(value=[image])),
                "label": Feature(int64_list=Int64List(value=[label]))
            }))
This is the error I get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-10-7e28c1154183> in <module>
12 features=Features(
13 feature={
---> 14 "image": Feature(bytes_list=BytesList(value=[image])),
15 "label": Feature(int64_list=Int64List(value=[label]))
16 }))
TypeError: <tf.Tensor: shape=(), dtype=string, numpy=b'\x08\x04\x12\x0c\x12\x02\x08\x1c\x12\x02\x08\x1c\x12\x02 has type tensorflow.python.framework.ops.EagerTensor, but expected one of: bytes
I also tried dropping tf.io.serialize_tensor(data["image"]) and passing data["image"] directly into the Feature constructor, which results in a different error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-11-5d1317c1b7df> in <module>
12 features=Features(
13 feature={
---> 14 "image": Feature(bytes_list=BytesList(value=[data["image"]])),
15 "label": Feature(int64_list=Int64List(value=[label]))
16 }))
TypeError: <tf.Tensor: shape=(28, 28, 1), dtype=uint8, numpy=
array([[[ 0],
[ 0],
[ 0],
has type tensorflow.python.framework.ops.EagerTensor, but expected one of: bytes
What am I doing wrong, and how can I fix it?
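For reference, both tracebacks say the same thing: BytesList expects Python bytes, but it is being handed an EagerTensor. One possible fix, offered as a sketch rather than a definitive answer, is to call .numpy() on the serialized tensor (and convert the label to a plain int) before building the Example. The helper name to_example is hypothetical, and a synthetic all-zero 28x28x1 image stands in for a real MNIST record so the snippet runs without downloading anything; eager execution (the TF 2.x default) is assumed.

```python
import tensorflow as tf


def to_example(image, label):
    """Build a tf.train.Example from an image tensor and a scalar label tensor.

    Hypothetical helper, not part of TensorFlow.
    """
    # serialize_tensor returns a scalar string EagerTensor;
    # .numpy() turns it into the Python bytes object that BytesList expects.
    image_bytes = tf.io.serialize_tensor(image).numpy()
    # Int64List expects Python ints, so unwrap the label tensor as well.
    label_int = int(label.numpy())
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                "image": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image_bytes])),
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label_int])),
            }))


# Synthetic stand-in for one MNIST record (assumption: same shape and dtypes).
image = tf.zeros((28, 28, 1), dtype=tf.uint8)
label = tf.constant(7, dtype=tf.int64)

example = to_example(image, label)
serialized = example.SerializeToString()

# Round trip: parse the Example back and recover the image tensor.
parsed = tf.train.Example.FromString(serialized)
restored = tf.io.parse_tensor(
    parsed.features.feature["image"].bytes_list.value[0], out_type=tf.uint8)
```

Inside the original loop this would amount to replacing value=[image] with value=[image.numpy()] and value=[label] with value=[int(label.numpy())]. Note that .numpy() is only available in eager mode, so this approach would not work as-is inside a function passed to Dataset.map.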