从tfRecords中提取时找不到功能

时间:2018-06-29 19:23:18

标签: python tensorflow tensorflow-datasets

我正在尝试使用tfRecords存储两个numpy数组和一些元信息,然后将其作为tf.data.TFRecordDataset加载,以使用非常方便的批处理。另外,对于tensorflow和我的ram来说,数组太大了,所以它是剩下的少数选择之一。

def toTfRecords(Hdf5NormalizedImage, Hdf5GroundTruth):
        featureList = keys(Hdf5NormalizedImage)
        gtList = keys(Hdf5GroundTruth)

    writer = tf.python_io.TFRecordWriter('../src/tfRecord.tfrecords')
    for i in tqdm(range(featureList.__len__())):
        featurePadded = np.zeros(shape=(448, 448, 3))

        feature = Hdf5NormalizedImage[featureList[i]][:]
        groundtruth = Hdf5GroundTruth[featureList[i]][:]

        padded = featurePadded[:feature.shape[0], :feature.shape[1], :feature.shape[2]] = feature
        padMap = padded.shape
        grdMap = groundtruth.shape

        features = {
            'feature': tf.train.Feature(float_list=(tf.train.FloatList(value = feature))),
            'groundtruth': tf.train.Feature(float_list=(tf.train.FloatList(value = groundtruth))),
            'padMap': tf.train.Feature(float_list=(tf.train.FloatList(value = padMap))),
            'grdMap': tf.train.Feature(float_list=(tf.train.FloatList(value = grdMap))),
        }

        example = tf.train.Example(features=tf.train.Features(feature=features))
        writer.write(example.SerializeToString())
    writer.close()
    return

def parse_proto(example_proto):
    features = {
        'feature': tf.FixedLenFeature([], tf.float32),
        'groundtruth': tf.FixedLenFeature([], tf.float32),
        'padMap': tf.FixedLenFeature([], tf.int64),
        'grdMap': tf.FixedLenFeature([], tf.int64),
    }
    parsed_features = tf.parse_single_example(example_proto, features)

    return parsed_features['feature'], parsed_features['groundtruth'], parsed_features['padMap'], parsed_features['grdMap']


def read_tfrecords():
    dataset = tf.data.TFRecordDataset("../src/tfRecord.tfrecords")
    dataset = dataset.map(parse_proto)
    dataset = dataset.shuffle(256)
    dataset = dataset.repeat()
    dataset = dataset.batch(128)
    return dataset

graph = tf.Graph()
    with graph.as_default():
        tfDataset = read_tfrecords()
        iter = tfDataset.make_one_shot_iterator()

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=graph) as sess:
        a = sess.run(iter.get_next())

错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Feature: feature (data type: float) is required but could not be found.
     [[Node: ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_FLOAT, DT_INT64, DT_FLOAT, DT_INT64], dense_keys=["feature", "grdMap", "groundtruth", "padMap"], dense_shapes=[[], [], [], []], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_2, ParseSingleExample/Const_3)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?]], output_types=[DT_FLOAT, DT_FLOAT, DT_INT64, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

错误指向a = sess.run(iter.get_next())的行

0 个答案:

没有答案