I have some doubts about how tf.train.string_input_producer works. Suppose I give a filename_list as the input argument to string_input_producer. Then, according to the documentation https://www.tensorflow.org/programmers_guide/reading_data, this creates a FIFOQueue, for which I can set the number of epochs, shuffle the filenames, and so on. In my case, I have 4 filenames ("db1.tfrecords", "db2.tfrecords", ...), and I use tf.train.batch to feed the network batches of images. In addition, each file_name/database contains the set of images of a single person; the second database is for a second person, and so on. So far I have the following code:
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
                          (common + "P21_db.tfrecords")]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')

reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    # Defaults are not specified since both keys are required.
    features={
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'image_raw': tf.FixedLenFeature([], tf.string),
        'annotation_raw': tf.FixedLenFeature([], tf.string)
    })

image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image, [height, width, 3])
annotation = tf.cast(features['annotation_raw'], tf.string)

min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
                                           shapes=[[], [112, 112, 3]],
                                           batch_size=batch_size,
                                           capacity=capacity,
                                           num_threads=num_threads)
Finally, when I look at the reconstructed images at the output of the autoencoder, I first get the images from the first database, then I start seeing images from the second database, and so on.
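As a sanity check, the key returned by reader.read can make this sequential behavior visible: for a TFRecordReader the key contains the source filename, so printing a few keys shows one file being drained before the next begins. A minimal sketch, assuming the graph above:

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # The filename in the key only changes once the current file is exhausted.
    for _ in range(5):
        print(sess.run(key))
    coord.request_stop()
    coord.join(threads)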
My questions: How can I tell whether I am still within the same epoch? And, while within a single epoch, how can I merge a batch of images from all of my file_names?
Finally, I tried to print out the value of the epoch counter by evaluating a local variable in the Session, as follows:

epoch_var = tf.local_variables()[0]

Then:

with tf.Session() as sess:
    print(sess.run(epoch_var))  # Here I got 9 as output; I don't know why.
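Side note: tf.local_variables()[0] is not guaranteed to be the epoch counter. A sketch of a more robust lookup, assuming the TF 1.x implementation detail that string_input_producer (with num_epochs set) internally calls limit_epochs, which creates a local variable whose name contains 'limit_epochs/epochs' and counts the epochs started so far:

# Hedged: relies on the internal variable name used by limit_epochs.
epoch_counters = [v for v in tf.local_variables() if 'limit_epochs/epochs' in v.name]

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # ... run training steps here ...
    print(sess.run(epoch_counters[0]))  # epochs started so far
    coord.request_stop()
    coord.join(threads)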
Any help is much appreciated!!
Answer 0 (score: 0)
What I figured out is that using tf.train.shuffle_batch_join solved my problem, since it starts shuffling images from the different datasets. In other words, every batch now contains images from all of the datasets/file_names. Here is an example:
def read_my_file_format(filename_queue):
    reader = tf.TFRecordReader()
    key, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'annotation_raw': tf.FixedLenFeature([], tf.string)
        })

    # This is how we create one example, that is, extract one example from the database.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    # The image is reshaped since, when stored in binary format, it is flattened.
    # Therefore, we need the height and the width to restore the original image.
    image = tf.reshape(image, [height, width, 3])
    annotation = tf.cast(features['annotation_raw'], tf.string)
    return annotation, image
def input_pipeline(filenames, batch_size, num_threads, num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=False,
                                                    name='queue')
    # Note that here we create num_threads readers, all reading from the same filename_queue.
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)]
    min_after_dequeue = 100
    capacity = min_after_dequeue + num_threads * batch_size
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list,
                                                            shapes=[[], [112, 112, 3]],
                                                            batch_size=batch_size,
                                                            capacity=capacity,
                                                            min_after_dequeue=min_after_dequeue)
    return label_batch, images_batch, example_list
label_batch, images_batch, input_ann_img = \
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch)
This now creates several readers reading from the FIFOQueue, each followed by its own decoder. Finally, after the images are decoded, they are fed into another Queue, created by the call to tf.train.shuffle_batch_join, which supplies the network with batches of images.
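For completeness, a sketch of how these batches might be consumed, using the standard TF 1.x Coordinator/queue-runner boilerplate (the training op itself is left out):

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            # Each fetched batch now mixes images from all four files.
            labels, images = sess.run([label_batch, images_batch])
            # ... feed `images` to the autoencoder's train op here ...
    except tf.errors.OutOfRangeError:
        # Raised once num_epoch epochs have been consumed.
        print('Done: reached the epoch limit.')
    finally:
        coord.request_stop()
        coord.join(threads)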