我有一个 TensorFlow Dataset,由(图像,标签)对组成。我现在希望为每一对创建一个 tf.train.Example 对象,并将它们存储在 TFRecord 文件中。但是我一直遇到以下错误:
TypeError: <tf.Tensor 'Cast:0' shape=() dtype=int64> has type Tensor, but expected one of: int, long
以下是一个可复现的示例。我哪里做错了?
import tensorflow as tf
tf.enable_eager_execution()

# Fetch the two sample images; the values returned by get_file are what
# end up as the "image" entries of the dataset below.
cat_in_snow = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg',
)
williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
)

# Parallel lists: labels[i] is the label of images[i].
labels = [0, 1]
images = [cat_in_snow, williamsburg_bridge]

# Slice the two lists in lockstep into a dataset of (image, label) pairs.
image_label_ds = tf.data.Dataset.from_tensor_slices((images, labels))
# Convenience functions for creating a tf.train.Feature.
def _bytes_feature(value):
    """Wrap a single string / bytes value in a tf.train.Feature (bytes_list)."""
    # BytesList takes a list, so the single value is wrapped in one.
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _float_feature(value):
    """Wrap a single float / double value in a tf.train.Feature (float_list)."""
    # FloatList takes a list, so the single value is wrapped in one.
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)
# Wraps `value` in a one-element list, builds a tf.train.Int64List from it and
# returns that list inside a tf.train.Feature.
# NOTE: the traceback below ends on this function's return line; when `value`
# is a tf.Tensor the call fails with
# "has type Tensor, but expected one of: int, long".
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
# Create tf.train.Example for each (image, label) pair.
# NOTE: the traceback below points at the "label" entry. The value produced by
# tf.cast(...) is a tf.Tensor (reported as <tf.Tensor 'Cast:0' ...>), which
# _int64_feature then passes to tf.train.Int64List, triggering the TypeError.
def image_label_to_example(image, label):
    feature = {
        "label": _int64_feature(tf.cast(label, tf.int64)),
        "image": _bytes_feature(image)
    }
    # Wrap the feature dict in a tf.train.Example and return the result of
    # its SerializeToString() call.
    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
    return tf_example.SerializeToString()
# Apply the conversion to every (image, label) element of the dataset.
# The .map(image_label_to_example) call is where the traceback below starts.
image_label_examples = image_label_ds.map(image_label_to_example)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-37-8baca3115c8a> in <module>
----> 1 image_label_examples = imgs_lbs_ds.map(image_label_to_example)
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls)
   1770     if num_parallel_calls is None:
   1771       return DatasetV1Adapter(
-> 1772           MapDataset(self, map_func, preserve_cardinality=False))
   1773     else:
   1774       return DatasetV1Adapter(
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   3188         self._transformation_name(),
   3189         dataset=input_dataset,
-> 3190         use_legacy_function=use_legacy_function)
   3191     variant_tensor = gen_dataset_ops.map_dataset(
   3192         input_dataset._variant_tensor,  # pylint: disable=protected-access
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   2553       resource_tracker = tracking.ResourceTracker()
   2554       with tracking.resource_tracker_scope(resource_tracker):
-> 2555         self._function = wrapper_fn._get_concrete_function_internal()
   2556         if add_to_graph:
   2557           self._function.add_to_graph(ops.get_default_graph())
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal(self, *args, **kwargs)
   1353     """Bypasses error checking when getting a graph function."""
   1354     graph_function = self._get_concrete_function_internal_garbage_collected(
-> 1355         *args, **kwargs)
   1356     # We're returning this concrete function to someone, and they may keep a
   1357     # reference to the FuncGraph without keeping a reference to the
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   1347     if self.input_signature:
   1348       args, kwargs = None, None
-> 1349     graph_function, _, _ = self._maybe_define_function(args, kwargs)
   1350     return graph_function
   1351
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   1650       graph_function = self._function_cache.primary.get(cache_key, None)
   1651       if graph_function is None:
-> 1652         graph_function = self._create_graph_function(args, kwargs)
   1653         self._function_cache.primary[cache_key] = graph_function
   1654       return graph_function, args, kwargs
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   1543             arg_names=arg_names,
   1544             override_flat_arg_shapes=override_flat_arg_shapes,
-> 1545             capture_by_value=self._capture_by_value),
   1546         self._function_attributes)
   1547
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    713                                           converted_func)
    714
--> 715       func_outputs = python_func(*func_args, **func_kwargs)
    716
    717       # invariant: `func_outputs` contains only Tensors, CompositeTensors,
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   2547           attributes=defun_kwargs)
   2548       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 2549         ret = _wrapper_helper(*args)
   2550         ret = self._output_structure._to_tensor_list(ret)
   2551         return [ops.convert_to_tensor(t) for t in ret]
//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   2487         nested_args = (nested_args,)
   2488
-> 2489       ret = func(*nested_args)
   2490       # If `func` returns a list of tensors, `nest.flatten()` and
   2491       # `ops.convert_to_tensor()` would conspire to attempt to stack
<ipython-input-36-9376dfea556f> in image_label_to_example(image, label)
      1 def image_label_to_example(image, label):
      2     feature = {
----> 3         "label": _int64_feature(tf.cast(label, tf.int64)),
      4         "image": _bytes_feature(image)
      5     }
<ipython-input-35-b0b600e28608> in _int64_feature(value)
     12 def _int64_feature(value):
     13     """Returns an int64_list from a bool / enum / int / uint."""
---> 14     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
TypeError: <tf.Tensor 'Cast:0' shape=() dtype=int64> has type Tensor, but expected one of: int, long
系统信息:
- Python 3.7.3
- TensorFlow 1.14.0
- macOS Mojave 10.14.6