如何将字符串张量(从.tfrecords获取)转换为浮点张量?

时间:2019-03-21 12:59:03

标签: python tensorflow

我的网络输入来自包含int32的文件。它们存储为.tfrecords,如下所示:

  writer = tf.python_io.TFRecordWriter(output_file)
  with tf.gfile.FastGFile(file_path, 'rb') as f:
    data = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
      'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data])) }))
    writer.write(example.SerializeToString())

然后我像这样读取tfrecords文件:

with tf.name_scope(self.name):
  filename_queue = tf.train.string_input_producer([path])
  reader = tf.TFRecordReader()

  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={ 'data': tf.FixedLenFeature([], tf.string) })

  data = features['data']

在读取tfrecords之后,我得到了这样的字符串张量:

Tensor("X/ParseSingleExample/ParseSingleExample:0", shape=(), dtype=string)

我想首先将其转换为int32,因为这就是初始数据所代表的含义。之后,我需要结束一个浮点张量,有人可以指向我正确的方向吗?

tensorflow的PS Im新手,请告诉我是否可以提供更多有用的信息

1 个答案:

答案 0 :(得分:3)

这应该有帮助

data = features['data']
decoded = tf.decode_raw(data, tf.int32)

这将输出dtype tf.int32的张量。然后,您可以调整其形状并投射到tf.float32

decoded = tf.reshape(decoded, shape)
decoded = tf.cast(decoded, tf.float32)

如果要检查tf.Session之外的tfrecords文件的内容

for str_rec in tf.python_io.tf_record_iterator('file.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(str_rec)
    data_str = example.features.feature['data'].bytes_list.value[0])
    decoded = np.fromstring(data_str, dtype)

要验证张量的内容,您可以按照此answer

中的说明在图中注入打印节点。
# Add print operation
decoded = tf.Print(decoded, [decoded])