What exactly does `tf.contrib.rnn.DropoutWrapper` in TensorFlow do? (three critical questions)

Date: 2017-08-04 12:50:24

Tags: python-3.x tensorflow neural-network bayesian rnn

As far as I know, DropoutWrapper is used as follows:

__init__(
cell,
input_keep_prob=1.0,
output_keep_prob=1.0,
state_keep_prob=1.0,
variational_recurrent=False,
input_size=None,
dtype=None,
seed=None
)


cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=0.5)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

The only thing I know is that it is used for dropout during training. Here are my three questions:

  1. What are input_keep_prob, output_keep_prob, and state_keep_prob, respectively? (I guess they define the dropout probability of each part of the RNN, but exactly where?)

  2. Is dropout in this context applied to the RNN not only during training but also during prediction? If so, is there any way to control whether dropout is used at prediction time?

  3. According to the API documentation on the TensorFlow web page, if variational_recurrent=True, dropout works according to the method in the paper "Y. Gal, Z. Ghahramani. 'A Theoretically Grounded Application of Dropout in Recurrent Neural Networks', https://arxiv.org/abs/1512.05287". I understood this paper roughly. When I train an RNN, I use a batch, not a single time series. In this case, does TensorFlow automatically assign a different dropout mask to each time series in the batch?
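To make the distinction in question 3 concrete, here is a minimal numpy sketch (all names are mine, not TensorFlow's) of how a variational recurrent dropout mask differs from a standard one: the variational scheme samples one mask per sequence and reuses it at every time step, while standard dropout resamples at each step.

```python
import numpy as np

rng = np.random.default_rng(0)
time_steps, batch, hidden = 5, 4, 8
keep_prob = 0.5

# Standard dropout: a fresh mask at every time step.
standard_masks = rng.random((time_steps, batch, hidden)) < keep_prob

# Variational dropout (Gal & Ghahramani): one mask per sequence,
# reused at every time step, but still sampled independently for
# each sequence in the batch.
per_sequence_mask = rng.random((batch, hidden)) < keep_prob
variational_masks = np.broadcast_to(per_sequence_mask,
                                    (time_steps, batch, hidden))

# Every time step sees the same mask for a given sequence...
assert (variational_masks[0] == variational_masks[-1]).all()
# ...while the masks for different sequences in the batch generally differ.
```

This is only an illustration of the masking scheme, not of how the wrapper is implemented internally.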

2 Answers:

Answer 0 (Score: 0)

  1. input_keep_prob is the dropout level (keep probability) applied to the inputs before the feature weights are applied. output_keep_prob is the dropout level applied to the output of each RNN cell. state_keep_prob is the dropout level applied to the hidden state that is passed on to the next step.
  2. You can initialize each of the parameters above as follows:
import tensorflow as tf

# Defaults to a keep probability of 1.0 (no dropout) unless a value is fed.
# Note: tf.placeholder_with_default requires a shape argument.
dropout_placeholder = tf.placeholder_with_default(tf.cast(1.0, tf.float32),
                                                  shape=())
cell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.BasicRNNCell(n_hidden_rnn),
                                     input_keep_prob=dropout_placeholder,
                                     output_keep_prob=dropout_placeholder,
                                     state_keep_prob=dropout_placeholder)

Any other value can be fed during either training or prediction; if nothing is fed, the default dropout level will be 1 (i.e., no dropout). So you typically feed a keep probability below 1 while training and feed nothing (or 1.0) at prediction time.
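Why a keep probability of 1 makes prediction dropout-free can be seen from a small numpy sketch of inverted dropout (the helper name and details are my own illustration, not TensorFlow's internals): kept units are scaled by 1/keep_prob so the expected activation matches between training and prediction, and keep_prob = 1.0 reduces the layer to the identity.

```python
import numpy as np

def dropout(x, keep_prob, rng):
    # Inverted dropout: zero units with probability (1 - keep_prob) and
    # scale the survivors by 1/keep_prob to preserve the expected value.
    if keep_prob >= 1.0:
        return x  # keep_prob = 1.0 disables dropout (the prediction default)
    mask = rng.random(x.shape) < keep_prob
    return x * mask / keep_prob

rng = np.random.default_rng(42)
x = np.ones((2, 3))
train_out = dropout(x, 0.5, rng)    # each unit is either 0.0 or 1/0.5 = 2.0
predict_out = dropout(x, 1.0, rng)  # identical to x: no dropout applied
```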

  3. The masking is applied to the fitted weights rather than to the individual sequences included in the batch. As far as I know, the same mask is applied across the whole batch.

Answer 1 (Score: -3)

# dropOut is assumed to be a boolean tensor: True during training,
# False at prediction time.
keep_prob = tf.cond(dropOut,
                    lambda: tf.constant(0.9),
                    lambda: tf.constant(1.0))
cells = rnn.DropoutWrapper(cells, output_keep_prob=keep_prob)