将DROPOUT添加到Tensorflow CIFAR10深度CNN示例

时间:2016-10-18 03:12:55

标签: tensorflow

我希望将dropout添加到tensorflow CIFAR10教程示例代码中,但是遇到了一些困难。

Deep MNIST tensorflow教程包含一个dropout示例,但它使用的是交互式图形,它与CIFAR10教程使用的方法不同。此外,CIFAR10教程不使用占位符,也不使用feed_dict将变量传递给优化器,MNIST模型使用该优化器来传递训练的丢失概率。

我在尝试什么:

在cifar10_train.train()中,我在默认图表下定义了丢失概率占位符;那就是:

def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)
    keep_drop_prob = = tf.placeholder(tf.float32)

在下面,仍然在train()模块中,当我通过调用cifar10.inference()构建计算图时,我也传递了keep_drop_prob占位符,如下所示:

"""Build a Graph that computes the logits predictions from the
inference model."""
logits = cifar10.inference(images, keep_drop_prob)

在cifar10.inference()模块中,我现在使用传递的keep_drop_prob占位符并使用它来定义我的dropout层,如下所示:

drop1 = tf.nn.dropout(norm1, keep_drop_prob)

现在我在计算损失时定义并传递keep_drop_prob的值,仍然在train()模块中,如下所示:

"""Calculate loss."""
loss = cifar10.loss(logits, labels, keep_drop_prob = 0.5)

然后在cifar10.loss()模块中,我在计算交叉熵时使用传递的keep_drop_prob值,如下所示:

"""Calculate the average cross entropy loss across the batch."""
labels = tf.cast(labels, tf.int64)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits, labels, keep_drop_prob, name='cross_entropy_per_example')

现在在这一点上,我不确定到目前为止我做了什么是正确的,以及我接下来需要做什么。

非常感谢任何帮助!

1 个答案:

答案 0 :(得分:2)

我相信我找到了解决方案。

似乎我在正确的轨道上,但是通过keep_drop_prob占位符有点过分了。

要添加辍学,我已完成以下工作:

我在cifar10_train.train()模块中添加了keep_drop_prob占位符,如下所示:

def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)
    keep_drop_prob = = tf.placeholder(tf.float32)

在cifar10_train.train()模块中构建图形时,我将它传递给占位符,但也定义了它的值

"""Build a Graph that computes the logits predictions from the
inference model."""
logits = cifar10.inference(images, keep_drop_prob=0.5)

在cifar10.inference()模块中,我现在使用传递的keep_drop_prob占位符并使用它来定义我的dropout层,并将其传递给激活摘要以登录tensorboard:

drop1 = tf.nn.dropout(norm1, keep_drop_prob)
_activation_summary(drop1) 

当我查看张量板图时,我看到我的辍学者在那里。我还可以在dropout op中查询keep_prob变量,并通过在构建logits图时更改我传递的值来影响其value属性。

我的下一个测试是将keep_drop_prob设置为1并设置为0,并确保从网络中获得预期的结果。

我不确定这是实施辍学的最有效方式,但我相当肯定它有效。

注意,我只有一个keep_drop_prob占位符,我传递给很多层的dropout(每个卷积atm后面一个)。我认为tensorflow为每个dropout op使用唯一的分布,而不是需要一个唯一的占位符。

编辑:不要忘记对eval模块进行必要的更改,但是为dropout传递值为1。