通过遵循mnist示例,我能够构建自定义网络并使用示例的inputs
函数来加载我的数据集(以前编码为TFRecord
)。回顾一下,inputs
函数如下所示:
def inputs(train_dir, train, batch_size, num_epochs, one_hot_labels=False):
if not num_epochs: num_epochs = None
filename = os.path.join(train_dir,
TRAIN_FILE if train else VALIDATION_FILE)
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(
[filename], num_epochs=num_epochs)
# Even when reading in multiple threads, share the filename
# queue.
image, label = read_and_decode(filename_queue)
# Shuffle the examples and collect them into batch_size batches.
# (Internally uses a RandomShuffleQueue.)
# We run this in two threads to avoid being a bottleneck.
images, sparse_labels = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size,
# Ensures a minimum amount of shuffling of examples.
min_after_dequeue=1000)
return images, sparse_labels
然后,在训练期间,我宣布训练操作员并运行一切,一切顺利。
现在,我正在尝试使用相同的函数在相同的数据上训练不同的网络,唯一(主要)的区别在于,而不仅仅是在某些slim.learning.train
上调用train_operator
函数,我手动进行培训(通过手动评估损失和更新参数)。架构更复杂,我不得不这样做。
当我尝试使用inputs
函数生成的数据时,程序会卡住,设置队列超时确实表明它已经卡在生产者的队列中。
这让我相信我可能错过了关于在tensorflow中使用生产者的一些东西,我已经阅读了这些教程,但我无法弄清楚这个问题。是否存在调用slim.learning.train
的某种初始化,如果我手动进行培训,我需要手动复制?为什么生产者生产究竟不是什么?
例如,执行以下操作:
imgs, labels = inputs(...)
print imgs
打印
<tf.Tensor 'input/shuffle_batch:0' shape=(1, 128, 384, 6) dtype=float32>
这是正确的(符号?)张量,但如果我尝试用imgs.eval()
来获取实际数据,它会被无限期地卡住。