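I have a tfrecord file in which I stored a list of data; each element has 2D coordinates and 3D coordinates. The coordinates are 2D numpy arrays of dtype float64.
These are the features I use to store them.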
feature = {'train/coord2d': _floats_feature(projC),
           'train/coord3d': _floats_feature(sChair)}
I then flatten them into a list of floats.
def _floats_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))
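A note on what this helper does to the data, sketched with plain NumPy under assumed shapes (the 4x2 array below is a hypothetical stand-in for projC): flatten() discards the 2D shape, and a FloatList holds 32-bit floats, so float64 values lose some precision when stored. Whoever reads the record back therefore has to know the original shape in order to restore it.

```python
import numpy as np

# Hypothetical stand-in for one element's 2D coordinates: 4 points, float64.
proj = np.arange(8, dtype=np.float64).reshape(4, 2) + 0.1

# flatten() returns a 1-D copy; the (4, 2) shape is not kept anywhere.
flat = proj.flatten()

# A FloatList stores 32-bit floats, so this cast mimics what ends up on disk.
stored = flat.astype(np.float32)

# The reader must already know there are 2 columns to undo the flattening.
restored = stored.reshape(-1, 2)

print(flat.shape, stored.dtype, restored.shape)
```

The restored array matches the original only up to float32 precision, which is usually acceptable for coordinates but worth knowing about.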
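Now I am trying to read them back so that I can feed them into my network for training. I want the 2D coords as the input and the 3D coords as the output for training my network.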
def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer(filename, name='queue')
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={'train/coord2d': tf.FixedLenFeature([], tf.float32),
                  'train/coord3d': tf.FixedLenFeature([], tf.float32)})
    coord2d = tf.cast(features['train/coord2d'], tf.float32)
    coord3d = tf.cast(features['train/coord3d'], tf.float32)
    return coord2d, coord3d

with tf.Session() as sess:
    filename = ["train.tfrecords"]
    dataset = tf.data.TFRecordDataset(filename)
    c2d, c3d = read_and_decode(filename)
    print(sess.run(c2d))
    print(sess.run(c3d))
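This is my code, but I don't really understand it, since I pieced it together from tutorials and the like. I tried printing c2d and c3d to see what format they are in, but my program just keeps running, prints nothing, and never terminates. Do c2d and c3d contain the 2D and 3D coordinates of every element in the dataset? Can they be used directly as the input and output when training the network?
I also don't know what format they should be in before I use them as network inputs. Should I convert them back into 2D numpy arrays or 2D tensors? In either case, how do I do that? Overall I am just very lost, so any guidance would be very helpful! Thanks.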
Answer 0 (score: 6)
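You are on the right track with `tf.data.TFRecordDataset(filename)`, but the problem is that `dataset` is not connected to the tensors you pass to `sess.run()`. Those tensors come from `read_and_decode()`, which uses the queue-based `tf.train.string_input_producer()`; no queue runners are ever started, so `sess.run()` blocks forever waiting for input.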
Here is a simple example program that should produce some output:
import tensorflow as tf

def decode(serialized_example):
    # NOTE: You might get an error here, because it seems unlikely that the features
    # called 'coord2d' and 'coord3d', and produced using `ndarray.flatten()`, will
    # have a scalar shape. You might need to change the shape passed to
    # `tf.FixedLenFeature()`.
    features = tf.parse_single_example(
        serialized_example,
        features={'train/coord2d': tf.FixedLenFeature([], tf.float32),
                  'train/coord3d': tf.FixedLenFeature([], tf.float32)})
    # NOTE: No need to cast these features, as they are already `tf.float32` values.
    return features['train/coord2d'], features['train/coord3d']

filename = ["train.tfrecords"]
dataset = tf.data.TFRecordDataset(filename).map(decode)
iterator = dataset.make_one_shot_iterator()
c2d, c3d = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run((c2d, c3d)))
    except tf.errors.OutOfRangeError:
        # Raised when we reach the end of the file.
        pass
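Following up on the NOTE in decode(): below is a minimal, self-contained sketch of one way to give the features a non-scalar shape and restore the 2D layout. It assumes every record holds the same number of points (NUM_POINTS, a hypothetical value), writes a small throwaway file of random data, and reads it back. write_example_file and read_all are illustrative helpers that are not part of the original code, and the tf.io / tf.compat.v1 spellings assume a recent TensorFlow 1.x or 2.x release that provides them.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# ASSUMPTION: every record stores exactly this many points. Adjust to the real data.
NUM_POINTS = 4

def _floats_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.flatten()))

def write_example_file(path, num_records=3):
    """Hypothetical helper: writes random (NUM_POINTS, 2) / (NUM_POINTS, 3) arrays."""
    rng = np.random.RandomState(0)
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(num_records):
            feature = {
                'train/coord2d': _floats_feature(rng.rand(NUM_POINTS, 2)),
                'train/coord3d': _floats_feature(rng.rand(NUM_POINTS, 3)),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

def decode(serialized_example):
    # The feature length must match the number of flattened values per record.
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'train/coord2d': tf.io.FixedLenFeature([NUM_POINTS * 2], tf.float32),
            'train/coord3d': tf.io.FixedLenFeature([NUM_POINTS * 3], tf.float32),
        })
    # Undo ndarray.flatten() by restoring the known 2D shapes.
    coord2d = tf.reshape(features['train/coord2d'], [NUM_POINTS, 2])
    coord3d = tf.reshape(features['train/coord3d'], [NUM_POINTS, 3])
    return coord2d, coord3d

def read_all(path):
    """Hypothetical helper: returns a list of (coord2d, coord3d) NumPy pairs."""
    results = []
    # An explicit graph keeps the session-style loop usable under TF 2.x as well.
    with tf.Graph().as_default():
        dataset = tf.data.TFRecordDataset([path]).map(decode)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        c2d, c3d = iterator.get_next()
        with tf.compat.v1.Session() as sess:
            try:
                while True:
                    results.append(sess.run((c2d, c3d)))
            except tf.errors.OutOfRangeError:
                # Raised when we reach the end of the file.
                pass
    return results

if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "example.tfrecords")
    write_example_file(path)
    for coord2d, coord3d in read_all(path):
        print(coord2d.shape, coord3d.shape)
```

Each element of the returned list is a (coord2d, coord3d) pair of float32 NumPy arrays with shapes (NUM_POINTS, 2) and (NUM_POINTS, 3). Calling dataset.batch(...) before creating the iterator would yield batches that can be fed to a network as input and target. If the number of points varies between records, a fixed-length feature does not fit and the shape would have to be stored some other way.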