将Tensor转换为tf.train.Feature对象(TypeError:类型为Tensor,但应为以下类型之一:int,long)

时间:2019-08-14 11:06:02

标签: python python-3.x tensorflow keras tensor

我有一个TensorFlow Dataset,由(图像,标签)对组成。我现在希望为每个对象创建一个tf.train.Example对象,并将它们存储在TFRecord文件中。但是我一直遇到以下错误:

TypeError: <tf.Tensor 'Cast:0' shape=() dtype=int64> has type Tensor, but expected one of: int, long

以下是可重复的示例。我在做什么错了?

import tensorflow as tf
tf.enable_eager_execution()

# Download sample images.
cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')

# Put images and corresponding labels in lists.
images = [cat_in_snow, williamsburg_bridge]
labels = [0, 1]

# Create dataset from (image, label) pairs.
image_label_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Convenience features for creating a tf.train.Feature.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Create tf.train.Example for each (image, label) pair.
def image_label_to_example(image, label):
    feature = {
        "label": _int64_feature(tf.cast(label, tf.int64)),
        "image": _bytes_feature(image)
    }

    tf_example = tf.train.Example(features=tf.train.Features(feature=feature))

    return tf_example.SerializeToString()

image_label_examples = image_label_ds.map(image_label_to_example)
---

TypeError                                 Traceback (most recent call last)
<ipython-input-37-8baca3115c8a> in <module>
-&#x2014;> 1 image<sub>label</sub><sub>examples</sub> = imgs<sub>lbs</sub><sub>ds.map</sub>(image<sub>label</sub><sub>to</sub><sub>example</sub>)

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset<sub>ops.py</sub> in map(self, map<sub>func</sub>, num<sub>parallel</sub><sub>calls</sub>)
    1770     if num<sub>parallel</sub><sub>calls</sub> is None:
    1771       return DatasetV1Adapter(
-> 1772           MapDataset(self, map<sub>func</sub>, preserve<sub>cardinality</sub>=False))
    1773     else:
    1774       return DatasetV1Adapter(

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset<sub>ops.py</sub> in \_<sub>init</sub>\_<sub>(self, input<sub>dataset</sub>, map<sub>func</sub>, use<sub>inter</sub><sub>op</sub><sub>parallelism</sub>, preserve<sub>cardinality</sub>, use<sub>legacy</sub><sub>function</sub>)</sub>
    3188         self.<sub>transformation</sub><sub>name</sub>(),
    3189         dataset=input<sub>dataset</sub>,
-> 3190         use<sub>legacy</sub><sub>function</sub>=use<sub>legacy</sub><sub>function</sub>)
    3191     variant<sub>tensor</sub> = gen<sub>dataset</sub><sub>ops.map</sub><sub>dataset</sub>(
    3192         input<sub>dataset</sub>.<sub>variant</sub><sub>tensor</sub>,  # pylint: disable=protected-access

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset<sub>ops.py</sub> in \_<sub>init</sub>\_<sub>(self, func, transformation<sub>name</sub>, dataset, input<sub>classes</sub>, input<sub>shapes</sub>, input<sub>types</sub>, input<sub>structure</sub>, add<sub>to</sub><sub>graph</sub>, use<sub>legacy</sub><sub>function</sub>, defun<sub>kwargs</sub>)</sub>
    2553       resource<sub>tracker</sub> = tracking.ResourceTracker()
    2554       with tracking.resource<sub>tracker</sub><sub>scope</sub>(resource<sub>tracker</sub>):
-> 2555         self.<sub>function</sub> = wrapper<sub>fn</sub>.<sub>get</sub><sub>concrete</sub><sub>function</sub><sub>internal</sub>()
    2556         if add<sub>to</sub><sub>graph</sub>:
    2557           self.<sub>function.add</sub><sub>to</sub><sub>graph</sub>(ops.get<sub>default</sub><sub>graph</sub>())

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in \_get<sub>concrete</sub><sub>function</sub><sub>internal</sub>(self, \*args, \*\*kwargs)
    1353     """Bypasses error checking when getting a graph function."""
    1354     graph<sub>function</sub> = self.<sub>get</sub><sub>concrete</sub><sub>function</sub><sub>internal</sub><sub>garbage</sub><sub>collected</sub>(
-> 1355         \*args, \*\*kwargs)
    1356     # We're returning this concrete function to someone, and they may keep a
    1357     # reference to the FuncGraph without keeping a reference to the

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in \_get<sub>concrete</sub><sub>function</sub><sub>internal</sub><sub>garbage</sub><sub>collected</sub>(self, \*args, \*\*kwargs)
    1347     if self.input<sub>signature</sub>:
    1348       args, kwargs = None, None
-> 1349     graph<sub>function</sub>, \_, \_ = self.<sub>maybe</sub><sub>define</sub><sub>function</sub>(args, kwargs)
    1350     return graph<sub>function</sub>
    1351

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in \_maybe<sub>define</sub><sub>function</sub>(self, args, kwargs)
    1650       graph<sub>function</sub> = self.<sub>function</sub><sub>cache.primary.get</sub>(cache<sub>key</sub>, None)
    1651       if graph<sub>function</sub> is None:
-> 1652         graph<sub>function</sub> = self.<sub>create</sub><sub>graph</sub><sub>function</sub>(args, kwargs)
    1653         self.<sub>function</sub><sub>cache.primary</sub>[cache<sub>key</sub>] = graph<sub>function</sub>
    1654       return graph<sub>function</sub>, args, kwargs

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/eager/function.py in \_create<sub>graph</sub><sub>function</sub>(self, args, kwargs, override<sub>flat</sub><sub>arg</sub><sub>shapes</sub>)
    1543             arg<sub>names</sub>=arg<sub>names</sub>,
    1544             override<sub>flat</sub><sub>arg</sub><sub>shapes</sub>=override<sub>flat</sub><sub>arg</sub><sub>shapes</sub>,
-> 1545             capture<sub>by</sub><sub>value</sub>=self.<sub>capture</sub><sub>by</sub><sub>value</sub>),
    1546         self.<sub>function</sub><sub>attributes</sub>)
    1547

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/framework/func<sub>graph.py</sub> in func<sub>graph</sub><sub>from</sub><sub>py</sub><sub>func</sub>(name, python<sub>func</sub>, args, kwargs, signature, func<sub>graph</sub>, autograph, autograph<sub>options</sub>, add<sub>control</sub><sub>dependencies</sub>, arg<sub>names</sub>, op<sub>return</sub><sub>value</sub>, collections, capture<sub>by</sub><sub>value</sub>, override<sub>flat</sub><sub>arg</sub><sub>shapes</sub>)
    713                                           converted<sub>func</sub>)
    714
&#x2013;> 715       func<sub>outputs</sub> = python<sub>func</sub>(\*func<sub>args</sub>, \*\*func<sub>kwargs</sub>)
    716
    717       # invariant: \`func<sub>outputs</sub>\` contains only Tensors, CompositeTensors,

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset<sub>ops.py</sub> in wrapper<sub>fn</sub>(\*args)
    2547           attributes=defun<sub>kwargs</sub>)
    2548       def wrapper<sub>fn</sub>(\*args):  # pylint: disable=missing-docstring
-> 2549         ret = \_wrapper<sub>helper</sub>(\*args)
    2550         ret = self.<sub>output</sub><sub>structure</sub>.<sub>to</sub><sub>tensor</sub><sub>list</sub>(ret)
    2551         return [ops.convert<sub>to</sub><sub>tensor</sub>(t) for t in ret]

//anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset<sub>ops.py</sub> in \_wrapper<sub>helper</sub>(\*args)
    2487         nested<sub>args</sub> = (nested<sub>args</sub>,)
    2488
-> 2489       ret = func(\*nested<sub>args</sub>)
    2490       # If \`func\` returns a list of tensors, \`nest.flatten()\` and
    2491       # \`ops.convert<sub>to</sub><sub>tensor</sub>()\` would conspire to attempt to stack

<ipython-input-36-9376dfea556f> in image<sub>label</sub><sub>to</sub><sub>example</sub>(image, label)
    1 def image<sub>label</sub><sub>to</sub><sub>example</sub>(image, label):
    2     feature = {
-&#x2014;> 3         "label": \_int64<sub>feature</sub>(tf.cast(label, tf.int64)),
    4         "image": \_bytes<sub>feature</sub>(image)
    5     }

<ipython-input-35-b0b600e28608> in \_int64<sub>feature</sub>(value)
    12 def \_int64<sub>feature</sub>(value):
    13   """Returns an int64<sub>list</sub> from a bool / enum / int / uint."""
&#x2014;> 14   return tf.train.Feature(int64<sub>list</sub>=tf.train.Int64List(value=[value]))

TypeError: <tf.Tensor 'Cast:0' shape=() dtype=int64> has type Tensor, but expected one of: int, long

系统信息:
Python 3.7.3。
Tensorflow 1.14.0。
MacOS Mojave 10.14.6。

参考文献:
TFRecords and tf.Example
Load images

0 个答案:

没有答案