Mypy:Tensorflow TFRecords“字节”对象的“ mypy”类型

时间:2019-03-14 00:59:01

标签: python python-3.x tensorflow binary mypy

这很奇怪。所以我正在创建一些Tensorflow TFRecords文件来编码数据。我想检查单个记录文件的mypy类型,该文件编码为二进制字符串。

现在,当我运行以下代码并检查字符串type()时,其指示为<class 'bytes'>。但是,当我使用mypy reveal_type()时,它表示error: Revealed type is 'Any',因此看起来mypy不能识别字节类型。这有意义吗?我真的不想将某些内容编码为Any,因为这实际上并不能帮助捕获我想用mypy.捕获的错误类型。

这是我用来生成错误的示例代码。我从新的tensorflow TFRecords guide中获取了代码,但最后几行是我自己的。

import tensorflow as tf
import numpy as np
import typing

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# the number of observations in the dataset
n_observations = int(1e4)

# boolean feature, encoded as False or True
feature0 = np.random.choice([False, True], n_observations)

# integer feature, random from 0 .. 4
feature1 = np.random.randint(0, 5, n_observations)

# string feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)


def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.Example message ready to be written to a file.
  """

  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.

  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
reveal_type(serialized_example)

print(type(serialized_example)) 

0 个答案:

没有答案