我想使用张量流的辍学功能来检查我是否可以改善递归神经网络的结果(TPR,FPR)。 但是,我通过遵循指南来实现它。所以我不确定我是否犯了任何错误。但是如果我用来训练我的模型验证后的10个时期,我得到的结果几乎相同。这就是为什么我不确定是否正确使用辍学功能的原因。在下面的代码中这是正确的实现还是我做错了什么?如果我做对了所有事情,为什么为什么我得到的结果几乎相同?
hm_epochs = 10
n_classes = 2
batch_size = 128
chunk_size = 341
n_chunks = 5
rnn_size = 32
dropout_prop = 0.5 # Dropout, probability to drop a unit
batch_size_validation = 65536
x = tf.placeholder('float', [None, n_chunks, chunk_size])
y = tf.placeholder('float')
def recurrent_neural_network(x):
layer = {'weights':tf.Variable(tf.random_normal([rnn_size, n_classes])),
'biases':tf.Variable(tf.random_normal([n_classes]))}
x = tf.transpose(x, [1,0,2])
x = tf.reshape(x, [-1, chunk_size])
x = tf.split(x, n_chunks, 0)
lstm_cell = rnn.BasicLSTMCell(rnn_size)
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
output = tf.matmul(outputs[-1], layer['weights']) + layer['biases']
#DROPOUT Implementation -> is this code really working?
#The result is nearly the same after 20 epochs...
output_layer = tf.layers.dropout(output, rate=dropout_prop)
return output
def train_neural_network(x):
prediction = recurrent_neural_network(x)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction,labels=y))
optimizer = tf.train.AdamOptimizer().minimize(cost)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(1,hm_epochs+1):
epoch_loss = 0
for i in range(0, training_data.shape[0], batch_size):
epoch_x = np.array(training_data[i:i+batch_size, :, :], dtype='float')
epoch_y = np.array(training_labels[i:i+batch_size, :], dtype='float')
if len(epoch_x) != batch_size:
epoch_x = epoch_x.reshape((len(epoch_x), n_chunks, chunk_size))
else:
epoch_x = epoch_x.reshape((batch_size, n_chunks, chunk_size))
_, c = sess.run([optimizer, cost], feed_dict={x: epoch_x, y: epoch_y})
epoch_loss += c
train_neural_network(x)
print("rnn - finished!")
答案 0 :(得分:1)
以最基本的形式,辍学应该发生在单元格内部并应用于权重。您仅在之后应用它。 This article很好地解释了这一点,并提供了一些很好的可视化效果和很少的变化。
要在代码中使用它,您可以
实现自己的RNN单元,其中keep概率是用于初始化该单元的参数,或者是每次被调用时传入的参数。
使用rnn退出包装器here。