tf.contrib.rnn.DropoutWrapper

时间:2017-08-14 12:45:17

标签: python tensorflow

根据tf.contrib.rnn.DropoutWrapper的{​​{3}}:

  • output_keep_prob:单位Tensor或浮点数在0和1之间,输出保持概率;如果它是常数且为1,则不会添加输​​出丢失。
  • state_keep_prob:单位Tensor或浮点数在0和1之间,输出保持概率;如果它是常数和1,则不会添加输​​出丢失。状态丢失是在单元的输出状态上执行的。

这两个参数的描述几乎相同,对吧?

我将output_keep_prob设置为默认值state_keep_prob=0.2loss在400个小批量'之后始终在11.3左右。培训时,我将output_keep_prob=0.2state_keep_prob设置为默认值,我的模型返回的loss在20次小批量后迅速降至6.0左右!我花了4天时间找到这个bug,真的很神奇,谁能解释这两个参数之间的区别?非常感谢!

超级参数:

  • lr = 5E-4
  • batch_size = 32
  • state_size = 256
  • multirnn_depth = 2

这是API

2 个答案:

答案 0 :(得分:4)

  • state_keep_prob是添加到RNN隐藏状态的丢失。添加到时间步骤i状态的丢失将影响状态i+1, i+2, ...的计算。正如您所发现的,这种传播效应通常对学习过程有害。
  • output_keep_prob是添加到RNN输出的丢失,丢失对后续状态的计算没有影响。

答案 1 :(得分:1)

两者都被正确地称为输出保持概率,您应该使用哪一个取决于您是否决定使用输出状态来计算你的 logits

我正在提供一个代码片段供您使用并探索用例:

import tensorflow as tf
import numpy as np
tf.reset_default_graph()

# Create input data
X = np.random.randn(2, 20, 8)

# The first example is of length 6 
X[0,6:] = 0
X_lengths = [6, 20]
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size, state_is_tuple=True) for 
size in [3, 7]]
rnn_layers = [tf.nn.rnn_cell.DropoutWrapper(lstm_cell, 
state_keep_prob=0.8, output_keep_prob=0.8) for lstm_cell in 
rnn_layers]
# cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

outputs, states  = tf.nn.dynamic_rnn(
                                     cell=multi_rnn_cell,
                                     dtype=tf.float64,
                                     sequence_length=X_lengths,
                                     inputs=X)

result = tf.contrib.learn.run_n(
{"outputs": outputs, "states": states},
n=1,
feed_dict=None)
assert result[0]["outputs"].shape == (2, 20, 7)
print (result[0]["states"][0].h)
print (result[0]["states"][-1].h)
print (result[0]["outputs"][0][5])
print (result[0]["outputs"][-1][-1])
print(result[0]["outputs"].shape)
print(result[0]["outputs"][0].shape)
print(result[0]["outputs"][1].shape)
assert (result[0]["outputs"][-1][-1]==result[0]["states" 
[-1].h[-1]).all()
assert (result[0]["outputs"][0][5]==result[0]["states"] 
[-1].h[0]).all()

result[0]["outputs"][0][6:]将是所有0的数组。

state_keep_proboutput_keep_prob <1时,断言都会失败但是当等于相同的值时,如本示例所示,您可以看到它们产生的辍学掩码除外同样的最终状态。

如果你有变量sequence_length,你绝对应该使用states来计算你的logits,在这种情况下,在训练时使用state_keep_prob&lt; 1。

如果你计划使用输出(应该在常量sequence_length的情况下使用,否则它需要进一步操作以在变量sequence_length的情况下获得最终的有效状态,或者你可能需要输出每一个时间步)你应该在训练时使用output_keep_prob

如果output_keep_probstate_keep_prob同时使用不同的相应丢失值,那么您会在outputsstates中看到最终返回状态中的不同值以及不同的丢失掩码。