序列化数据与tenserflow TFRecordDataset代码

时间:2018-04-03 14:59:18

标签: python tensorflow tfrecord

我有一个大型的numpy整数数据集,我想用GPU进行分析。数据集太大,无法容纳GPU上的主内存,因此我尝试将它们序列化为TFRecord,然后使用API​​流式传输记录进行处理。下面的代码是示例代码:它想要创建一些伪数据,将其序列化为TFRecord对象,然后使用TF会话将数据读回内存,使用map()函数进行解析。我的原始数据在numpy数组的维度方面是非同质的,尽管每个都是一个3D数组,其中第一个轴的长度为10。当我制作假数据时,我使用随机数重新创建了非均匀性。我的想法是在序列化数据时存储每个图像的大小,我可以使用它来将每个阵列恢复到其原始大小。但是当我反序列化时,有两个问题:首先进入的数据与出来的数据不匹配(序列化不匹配反序列化)。其次,获取所有序列化数据的迭代器是不正确的。这是代码:

import numpy as np
from skimage import io
from skimage.io import ImageCollection 
import tensorflow as tf
import argparse

#A function for parsing TFRecords
def record_parser(record):
    keys_to_features = {
            'fil' : tf.FixedLenFeature([],tf.string),
            'm'   : tf.FixedLenFeature([],tf.int64),
            'n'   : tf.FixedLenFeature([],tf.int64)} 

    parsed = tf.parse_single_example(record, keys_to_features)

    m    = tf.cast(parsed['m'],tf.int64)
    n    = tf.cast(parsed['n'],tf.int64)

    fil_shape = tf.stack([10,m,n])
    fil = tf.decode_raw(parsed['fil'],tf.float32)
    print("size: ", tf.size(fil))
    fil = tf.reshape(fil,fil_shape)
    return (fil,m,n)

#For writing and reading from the TFRecord
filename = "test.tfrecord"

if __name__ == "__main__":

    #Create the TFRecordWriter
    data_writer = tf.python_io.TFRecordWriter(filename)

    #Create some fake data
    files = []
    i_vals = np.random.randint(20,size=10)
    j_vals = np.random.randint(20,size=10)

    print(i_vals)
    print(j_vals)
    for x in range(5):
        files.append(np.random.rand(10,i_vals[x],j_vals[x]).astype(np.float32))

    i=0
    #Serialize the fake data and record it as a TFRecord using the TFRecordWriter
    for fil in files:
        i+=1
        f,m,n = fil.shape
        fil_raw = fil.tostring()
        print(fil.shape)
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
                    'fil' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[fil_raw])),
                    'm'   : tf.train.Feature(int64_list=tf.train.Int64List(value=[m])),
                    'n'   : tf.train.Feature(int64_list=tf.train.Int64List(value=[n]))
                }
            )
        )
        data_writer.write(example.SerializeToString())
    data_writer.close()

    #Deserialize and report on the fake data
    sess = tf.Session()

    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.map(record_parser)

    iterator = dataset.make_initializable_iterator()

    next_element = iterator.get_next()

    sess.run(iterator.initializer)
    while True:
        try:
            sess.run(next_element)
            fil,m,n = (next_element[0],next_element[1],next_element[2])
            with sess.as_default():
                print("fil.shape: ",fil.eval().shape)
                print("M: ",m.eval())
                print("N: ",n.eval())
        except tf.errors.OutOfRangeError:
            break

这是输出:

MacBot$ python test.py
/Users/MacBot/anaconda/envs/tflow/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
[ 6  7  3 18  9 10  4  0  3 12]
[ 4  2 14  4 11  4  5  2  9 17]
(10, 6, 4)
(10, 7, 2)
(10, 3, 14)
(10, 18, 4)
(10, 9, 11)
2018-04-03 10:52:29.324429: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
size:  Tensor("Size:0", shape=(), dtype=int32)
fil.shape:  (10, 7, 2)
M:  3
N:  4

任何人都明白我做错了什么?谢谢你的帮助!

1 个答案:

答案 0 :(得分:0)

而不是

sess.run(iterator.initializer)
while True:
    try:
        sess.run(next_element)
        fil,m,n = (next_element[0],next_element[1],next_element[2])
        with sess.as_default():
            print("fil.shape: ",fil.eval().shape)
            print("M: ",m.eval())
            print("N: ",n.eval())
    except tf.errors.OutOfRangeError:
        break

应该是

sess.run(iterator.initializer)
while True:
    try:
        fil,m,n = sess.run(next_element)
        print("fil.shape: ",fil.eval().shape)
        print("M: ",m.eval())
        print("N: ",n.eval())
    except tf.errors.OutOfRangeError:
        break