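Reading multiple feature vectors from a single TFRecord example in TensorFlow

Time: 2018-07-19 22:47:48

Tags: tensorflow tfrecord

I know how to store one feature per example in a tfrecord file and then read it back with something like this: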

import tensorflow as tf
import numpy as np
import os


# This is used to parse an example from tfrecords
def parse(serialized_example):
  features = tf.parse_single_example(
    serialized_example,
    features ={
      "label": tf.FixedLenFeature([], tf.string, default_value=""),
      "feat": tf.FixedLenFeature([], tf.string, default_value="")
    })

  feat = tf.decode_raw(features['feat'], tf.float64)
  label = tf.decode_raw(features['label'], tf.int64)

  return feat, label


################# Generate data

cwd = os.getcwd()
numdata = 10
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(2)
        label = np.array(np.random.randint(0,9))

        featb  = feat.tobytes()
        labelb = label.tobytes()
        example = tf.train.Example(features=tf.train.Features(
            feature={
            'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),}))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

filename = ['data.tfrecords']
dataset = tf.data.TFRecordDataset(filename).map(parse)
dataset = dataset.batch(100)
iterator = dataset.make_initializable_iterator()
feat, label = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    try:
        while True:
            example = sess.run((feat,label))
            print(example)
    except tf.errors.OutOfRangeError:
        pass

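What if each example contains multiple feature vectors plus labels? For instance, suppose feat in the code above were stored as a 2D array. I still want to do the same thing as before, namely train a DNN with one feature per label, but each example in the tfrecords file now holds multiple features and multiple labels. This should be simple, but I am having trouble unpacking multiple features in tensorflow using tfrecords.

1 Answer:

Answer 0: (score: 1)

First, note that np.ndarray.tobytes() flattens a multi-dimensional array into a single flat sequence of bytes, i.e.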

feat = np.random.randn(N, 2)
reshaped = np.reshape(feat, (N*2,))
feat.tobytes() == reshaped.tobytes()   ## True

So if you have an N*2 array saved as bytes in TFRecord format, you have to reshape it after parsing.
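The round trip can be checked with NumPy alone. The sketch below is only an illustration: the value N = 3 and the variable names are assumptions, and np.frombuffer stands in for the role tf.decode_raw plays on the TensorFlow side (both hand back a flat array that still has to be reshaped).

```python
import numpy as np

N = 3  # assumed number of feature vectors per example (illustrative value)

# Build an (N, 2) array and serialize it; the bytes carry no shape information.
feat = np.arange(N * 2, dtype=np.float64).reshape(N, 2)
raw = feat.tobytes()

# Decoding the bytes gives a flat array of shape (N*2,).
flat = np.frombuffer(raw, dtype=np.float64)

# The original layout has to be restored explicitly.
restored = flat.reshape(N, 2)

print(flat.shape, restored.shape)
print(np.array_equal(feat, restored))
```

The dtype passed when decoding has to match the dtype used when writing, otherwise the bytes are reinterpreted as different numbers.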

Once you do that, you can unbatch the elements of the tf.data.Dataset so that each iteration gives you one feature and one label. Your code should look like this:

# This is used to parse an example from tfrecords
# N (the number of feature vectors per example) must be defined before this runs
def parse(serialized_example):
  features = tf.parse_single_example(
    serialized_example,
    features ={
      "label": tf.FixedLenFeature([], tf.string, default_value=""),
      "feat": tf.FixedLenFeature([], tf.string, default_value="")
    })

  feat = tf.decode_raw(features['feat'], tf.float64)    # array of shape (N*2, )
  feat = tf.reshape(feat, (N, 2))                       # array of shape (N, 2)
  label = tf.decode_raw(features['label'], tf.int64)    # array of shape (N, )

  return feat, label


################# Generate data

cwd = os.getcwd()
numdata = 10
with tf.python_io.TFRecordWriter(os.path.join(cwd, 'data.tfrecords')) as writer:
    for i in range(numdata):
        feat = np.random.randn(N, 2)
        label = np.array(np.random.randint(0,9, N))

        featb  = feat.tobytes()
        labelb = label.tobytes()
        example = tf.train.Example(features=tf.train.Features(
            feature={
            'feat': tf.train.Feature(bytes_list=tf.train.BytesList(value=[featb])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[labelb])),}))
        writer.write(example.SerializeToString())

        print('wrote f {}, l {}'.format(feat, label))

print('Done writing! Start reading and printing data')

################# Read data

filename = ['data.tfrecords']
dataset = tf.data.TFRecordDataset(filename).map(parse).apply(tf.contrib.data.unbatch())
... etc
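
To see what the unbatch step is meant to achieve, here is a NumPy-only analogue. It is not how tf.contrib.data.unbatch() is implemented; the helper name unbatch_pairs and the value N = 3 are made up for illustration. It splits one parsed example, holding feat of shape (N, 2) and label of shape (N,), into N separate (feature, label) pairs along the first axis.

```python
import numpy as np


def unbatch_pairs(feat, label):
    """Hypothetical helper: yield one (feature, label) pair per row."""
    # Both arrays must agree on the leading dimension N.
    assert feat.shape[0] == label.shape[0]
    for f, l in zip(feat, label):
        yield f, l


N = 3  # assumed number of feature vectors per example (illustrative value)
feat = np.arange(N * 2, dtype=np.float64).reshape(N, 2)  # shape (N, 2)
label = np.arange(N, dtype=np.int64)                      # shape (N,)

pairs = list(unbatch_pairs(feat, label))
for f, l in pairs:
    print(f.shape, l)
```

Each pair now has a feature of shape (2,) and a scalar label, which is the one-feature-per-label layout the question asks for.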