麻烦使用tf.data API,TFRecordDataset和序列化

时间:2018-04-03 13:17:20

标签: python tensorflow


import numpy as np
from skimage import io
from skimage.io import ImageCollection 
import tensorflow as tf
import argparse

#A function for parsing TFRecords
def record_parser(record):
    keys_to_features = {
            'fil' : tf.FixedLenFeature([],tf.string),
            'm'   : tf.FixedLenFeature([],tf.int64),
            'n'   : tf.FixedLenFeature([],tf.int64)} 

    parsed = tf.parse_single_example(record, keys_to_features)

    m    = tf.cast(parsed['m'],tf.int32)
    n    = tf.cast(parsed['n'],tf.int32)

    fil_shape = tf.stack([10,m,n])
    fil = tf.decode_raw(parsed['fil'],tf.float32)
    fil = tf.reshape(fil,fil_shape)

    return (fil,m,n)

#For writing and reading from the TFRecord
filename = "test.tfrecord"

if __name__ == "__main__":

    #Create the TFRecordWriter
    data_writer = tf.python_io.TFRecordWriter(filename)

    #Create some fake data
    files = []
    i_vals = np.random.randint(20,size=10)
    j_vals = np.random.randint(20,size=10)

    for x in range(5):

    #Serialize the fake data and record it as a TFRecord using the TFRecordWriter
    for fil in files:

        f,m,n = fil.shape
        fil_raw = fil.tostring()

        print("fil.shape: ",fil.shape)

        example = tf.train.Example(
            features = tf.train.Features(
                feature = {
                    'fil' : tf.train.Feature(bytes_list=tf.train.BytesList(value=[fil_raw])),
                    'm'   : tf.train.Feature(int64_list=tf.train.Int64List(value=[m])),
                    'n'   : tf.train.Feature(int64_list=tf.train.Int64List(value=[n]))

    #Deserialize and report on the fake data
    sess = tf.Session()

    dataset = tf.data.TFRecordDataset([filename])
    dataset = dataset.map(record_parser)

    iterator = dataset.make_initializable_iterator()

    next_element = iterator.get_next()

    while True:
            fil,m,n = next_element
            print("fil.shape: ",file.shape)
            print("M: ",m)
            print("N: ",n)
        except tf.errors.OutOfRangeError:


MacBot$ python test.py
[ 2 12 17 18 19 15 11  5  0 12]
[13  5  3  5  2  6  5 11 12 10]
fil.shape:  (10, 2, 13)
fil.shape:  (10, 12, 5)
fil.shape:  (10, 17, 3)
fil.shape:  (10, 18, 5)
fil.shape:  (10, 19, 2)
2018-04-03 09:01:18.382870: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-04-03 09:01:18.420114: W tensorflow/core/framework/op_kernel.cc:1202] OP_REQUIRES failed at iterator_ops.cc:870 : Invalid argument: Input to reshape is a tensor with 520 values, but the requested shape has 260
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](DecodeRaw, stack)]]
Traceback (most recent call last):
  File "/Users/MacBot/anaconda/envs/tflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1361, in _do_call
    return fn(*args)
  File "/Users/MacBot/anaconda/envs/tflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1340, in _run_fn
    target_list, status, run_metadata)
  File "/Users/MacBot/anaconda/envs/tflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 516, in __exit__
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 520 values, but the requested shape has 260
     [[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32](DecodeRaw, stack)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[10,?,?], [], []], output_types=[DT_FLOAT, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]


1 个答案:

答案 0 :

您似乎正在编写np.random.rand的结果。但是,这会返回float64个值。另一方面,您告诉Tensorflow将字节解释为float32。这是一个不匹配 - 并且可以解释为什么数字的数量是预期的两倍(因为字节数是原来的两倍!)。
