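I am trying to build a train op that uses tf.RandomShuffleQueue, based on TensorFlow's CIFAR10 example. My labels come from the filenames, as described in (Accessing filename from file queue in Tensor Flow). How can I make this code work?
When I try to run the following code, where path is a directory containing many files: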
filenames = [os.path.join(path, f) for f in os.listdir(path)][1:]
file_fifo = tf.train.string_input_producer(filenames,
                                           shuffle=False,
                                           capacity=len(filenames))
reader = tf.WholeFileReader()
key, value = reader.read(file_fifo)
image = tf.image.decode_png(value, channels=3, dtype=tf.uint8)
image.set_shape([config.image_height, config.image_width, config.image_depth])
image = tf.cast(image, tf.float32)
image = tf.divide(image, 255.0)
labels = [int(os.path.basename(f).split('_')[-1].split('.')[0]) for f in filenames]
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
label_enqueue = label_fifo.enqueue_many([tf.constant(labels)])
label = label_fifo.dequeue()
bq = tf.RandomShuffleQueue(capacity=16 * batch_size,
                           min_after_dequeue=8 * batch_size,
                           dtypes=[tf.float32, tf.int32])
batch_enqueue_op = bq.enqueue([image, label_enqueue])
runner = tf.train.queue_runner.QueueRunner(bq, [batch_enqueue_op] * num_threads)
tf.train.add_queue_runner(runner)
# Read 'batch' labels + images from the example queue.
images, labels = bq.dequeue_many(FLAGS.batch_size)
labels = tf.reshape(labels, [FLAGS.batch_size, 1])
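I get the obvious errors, since I know my code does not make much sense. I have two FIFO queues, file_fifo and label_fifo, but I don't know how to feed label_fifo into my tf.RandomShuffleQueue.
Can anyone help? Thanks :-)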
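The label-from-filename rule used in the code above can be checked on its own, without TensorFlow. This is a minimal sketch: the helper name label_from_filename and the example paths are hypothetical, and it assumes filenames end in an underscore, an integer label, and an extension.

```python
import os


def label_from_filename(filename):
    """Return the integer between the last '_' and the first following '.'.

    Same expression as in the question's list comprehension; the naming
    scheme (e.g. 'img_0042_7.png' -> 7) is an assumption.
    """
    return int(os.path.basename(filename).split('_')[-1].split('.')[0])


# Hypothetical paths, for illustration only.
example_labels = [label_from_filename(p)
                  for p in ['/data/train/img_0042_7.png', 'img_3.png']]
```

A file whose name does not follow this pattern makes int() raise ValueError, so it may be worth filtering the directory listing before building the queues.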
Answer 0 (score: 0)
I changed the code to:
filenames = [os.path.join(FLAGS.data_path, f) for f in os.listdir(FLAGS.data_path)][1:]
np.random.shuffle(filenames)
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
reader = tf.WholeFileReader()
key, value = reader.read(file_fifo)
image = tf.image.decode_png(value, channels=3, dtype=tf.uint8)
image.set_shape([config.image_height, config.image_width, config.image_depth])
image = tf.cast(image, tf.float32)
image = tf.divide(image, 255.0)
labels = [int(os.path.basename(f).split('_')[-1].split('.')[0]) for f in filenames]
label_fifo = tf.FIFOQueue(len(filenames), tf.int32, shapes=[[]])
label_enqueue = label_fifo.enqueue_many([tf.constant(labels)])
label = label_fifo.dequeue()
if is_train:
    images, label_batch = tf.train.shuffle_batch([image, label],
                                                 batch_size=FLAGS.batch_size,
                                                 num_threads=FLAGS.num_threads,
                                                 capacity=16 * FLAGS.batch_size,
                                                 min_after_dequeue=8 * FLAGS.batch_size)
    labels = tf.reshape(label_batch, [FLAGS.batch_size, 1])
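The order of operations in the code above matters: the filenames are shuffled first and the labels are derived from the already shuffled list, so position i in both lists refers to the same file. A plain-Python sketch of that idea follows; the function name shuffle_then_label and the file names are hypothetical, and random.Random stands in for np.random.shuffle.

```python
import os
import random


def shuffle_then_label(filenames, seed=0):
    """Shuffle a copy of the filenames, then derive labels in that order."""
    shuffled = list(filenames)
    random.Random(seed).shuffle(shuffled)
    # Labels are computed after shuffling, so labels[i] describes shuffled[i].
    labels = [int(os.path.basename(f).split('_')[-1].split('.')[0])
              for f in shuffled]
    return shuffled, labels


# Hypothetical file names following the '<name>_<label>.png' pattern.
files = ['cat_0.png', 'dog_1.png', 'bird_2.png', 'frog_3.png']
shuffled, labels = shuffle_then_label(files)
```

This only guarantees that the two Python lists line up. Whether images and labels stay paired inside the graph also depends on both queues being consumed in lockstep, which is harder to guarantee when several threads run the enqueue op.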
For training, I have:
class _LoggerHook(tf.train.SessionRunHook):
    """Logs loss and runtime."""
    def begin(self):
        self._step = -1
    def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        if self._step % int(config.train_examples / FLAGS.batch_size) == 0 or self._step == 0:
            run_context.session.run(label_enqueue_op)
        return tf.train.SessionRunArgs({'loss': loss, 'net': net})
and I run training like this:
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_path,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()],
        config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    while not mon_sess.should_stop():
        mon_sess.run(train_op)
Training starts, but it runs only the first step and then hangs, perhaps because it is waiting on some queue operation.
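One possible explanation, offered as an assumption rather than a confirmed diagnosis: the hook refills the label queue only when the step counter reaches an epoch boundary, while the shuffle queue will not hand out a batch unless min_after_dequeue elements would remain afterwards. If the labels run out before that boundary is reached, the dequeue blocks and the refill is never triggered. The sketch below is plain-Python bookkeeping, not TensorFlow; the function name steps_before_stall and the example numbers are hypothetical.

```python
def steps_before_stall(num_examples, batch_size, min_after_dequeue,
                       capacity, max_steps):
    """Count the training steps that complete before the model stalls.

    Simplified model: labels are refilled only at epoch boundaries of the
    training-step counter, and a batch is released only if at least
    min_after_dequeue elements would stay in the shuffle buffer.
    """
    steps_per_epoch = max(1, num_examples // batch_size)
    labels_waiting = 0  # elements sitting in the label FIFO
    buffered = 0        # elements sitting in the shuffle buffer
    for step in range(max_steps):
        if step % steps_per_epoch == 0:
            labels_waiting += num_examples  # the hook's refill
        # Background threads move as many (image, label) pairs as fit.
        moved = min(labels_waiting, capacity - buffered)
        labels_waiting -= moved
        buffered += moved
        if buffered - batch_size < min_after_dequeue:
            return step  # the dequeue would block here
        buffered -= batch_size
    return max_steps


# Hypothetical numbers: 1000 examples, batch 64, buffer settings as in the post.
stalled_at = steps_before_stall(1000, 64, 8 * 64, 16 * 64, max_steps=100)
```

With these numbers the model stalls at step 7, well before the refill at step 15. With min_after_dequeue set to 0 the same model keeps running, which suggests that the refill schedule and the shuffle buffer settings are worth checking together.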