我正在尝试使用tfRecords存储两个numpy数组和一些元信息,然后将其作为tf.data.TFRecordDataset加载,以使用非常方便的批处理。另外,对于tensorflow和我的ram来说,数组太大了,所以它是剩下的少数选择之一。
def toTfRecords(Hdf5NormalizedImage, Hdf5GroundTruth):
featureList = keys(Hdf5NormalizedImage)
gtList = keys(Hdf5GroundTruth)
writer = tf.python_io.TFRecordWriter('../src/tfRecord.tfrecords')
for i in tqdm(range(featureList.__len__())):
featurePadded = np.zeros(shape=(448, 448, 3))
feature = Hdf5NormalizedImage[featureList[i]][:]
groundtruth = Hdf5GroundTruth[featureList[i]][:]
padded = featurePadded[:feature.shape[0], :feature.shape[1], :feature.shape[2]] = feature
padMap = padded.shape
grdMap = groundtruth.shape
features = {
'feature': tf.train.Feature(float_list=(tf.train.FloatList(value = feature))),
'groundtruth': tf.train.Feature(float_list=(tf.train.FloatList(value = groundtruth))),
'padMap': tf.train.Feature(float_list=(tf.train.FloatList(value = padMap))),
'grdMap': tf.train.Feature(float_list=(tf.train.FloatList(value = grdMap))),
}
example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(example.SerializeToString())
writer.close()
return
def parse_proto(example_proto):
features = {
'feature': tf.FixedLenFeature([], tf.float32),
'groundtruth': tf.FixedLenFeature([], tf.float32),
'padMap': tf.FixedLenFeature([], tf.int64),
'grdMap': tf.FixedLenFeature([], tf.int64),
}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features['feature'], parsed_features['groundtruth'], parsed_features['padMap'], parsed_features['grdMap']
def read_tfrecords():
dataset = tf.data.TFRecordDataset("../src/tfRecord.tfrecords")
dataset = dataset.map(parse_proto)
dataset = dataset.shuffle(256)
dataset = dataset.repeat()
dataset = dataset.batch(128)
return dataset
graph = tf.Graph()
with graph.as_default():
tfDataset = read_tfrecords()
iter = tfDataset.make_one_shot_iterator()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=graph) as sess:
a = sess.run(iter.get_next())
错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Feature: feature (data type: float) is required but could not be found.
[[Node: ParseSingleExample/ParseSingleExample = ParseSingleExample[Tdense=[DT_FLOAT, DT_INT64, DT_FLOAT, DT_INT64], dense_keys=["feature", "grdMap", "groundtruth", "padMap"], dense_shapes=[[], [], [], []], num_sparse=0, sparse_keys=[], sparse_types=[]](arg0, ParseSingleExample/Const, ParseSingleExample/Const_1, ParseSingleExample/Const_2, ParseSingleExample/Const_3)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?]], output_types=[DT_FLOAT, DT_FLOAT, DT_INT64, DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
错误指向a = sess.run(iter.get_next())
的行