I'm finding it hard to understand the concept behind control_dependencies in TensorFlow. Below is a snippet of LSTM code in TensorFlow. Can someone explain what control_dependencies does and why it is needed here?
import tensorflow as tf

# vocab_size, batch_size and num_unrollings are defined earlier in my script.
num_nodes = 64

graph = tf.Graph()
with graph.as_default():
    # Parameters
    # Input gate: input, previous output, and bias.
    ix = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    im = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    ib = tf.Variable(tf.zeros(shape=(1, num_nodes)))
    # Forget gate: input, previous output, and bias.
    fx = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    fm = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    fb = tf.Variable(tf.zeros(shape=(1, num_nodes)))
    # Memory cell: input, state, and bias.
    cx = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    cm = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    cb = tf.Variable(tf.zeros(shape=(1, num_nodes)))
    # Output gate: input, previous output, and bias.
    ox = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    om = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    ob = tf.Variable(tf.zeros(shape=(1, num_nodes)))
    # Variables saving state and output across unrollings.
    saved_output = tf.Variable(tf.zeros(shape=(batch_size, num_nodes)), trainable=False)
    saved_state = tf.Variable(tf.zeros(shape=(batch_size, num_nodes)), trainable=False)
    # Classifier weights and biases.
    w = tf.Variable(tf.truncated_normal(shape=(num_nodes, vocab_size), mean=-0.1, stddev=0.1))
    b = tf.Variable(tf.zeros(shape=(vocab_size,)))

    # Definition of a single LSTM cell step.
    def lstm_cell(i, o, state):
        '''
        See http://deeplearning.net/tutorial/lstm.html for the definition.
        '''
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
        state = forget_gate * state + input_gate * tf.tanh(update)
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
        return output_gate * tf.tanh(state), state

    # Input data.
    train_data = list()
    for _ in range(num_unrollings + 1):
        train_data.append(tf.placeholder(tf.float32, shape=(batch_size, vocab_size)))
    train_inputs = train_data[:num_unrollings]
    train_labels = train_data[1:]  # Labels are inputs shifted by one time step.

    # Unrolled LSTM loop.
    outputs = list()
    output = saved_output
    state = saved_state
    for i in train_inputs:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)

    # State saving across unrollings.
    with tf.control_dependencies([saved_output.assign(output),
                                  saved_state.assign(state)]):
        # Classifier.
        logits = tf.nn.xw_plus_b(tf.concat(0, outputs), w, b)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits, tf.concat(0, train_labels)))

    # Optimizer.
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(10.0, global_step, 5000, 0.1, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
    optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)
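For context, here is the minimal toy example I have been using to poke at tf.control_dependencies on its own. The `counter` variable and the two reads are my own construction (not part of the course code), written against the same pre-1.0 TensorFlow API as the snippet above. My current, possibly wrong, understanding is that ops created inside the block only run after the listed ops have run:

    import tensorflow as tf

    counter = tf.Variable(0, name='counter')
    increment = counter.assign_add(1)

    # A read that does NOT depend on the increment.
    plain_read = tf.identity(counter)

    # Ops created inside this block should only run after `increment` has run.
    with tf.control_dependencies([increment]):
        dependent_read = tf.identity(counter)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print(sess.run(plain_read))      # I expect 0: nothing forces `increment` to run
        print(sess.run(dependent_read))  # I expect 1: `increment` should run first

What I don't get in the LSTM code is why `saved_output.assign(output)` and `saved_state.assign(state)` have to be attached as control dependencies of the classifier/loss ops, instead of simply being run on their own.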