control_dependencies in TensorFlow

Asked: 2017-03-27 19:12:16

Tags: tensorflow deep-learning

I am finding it hard to understand the concept behind control_dependencies in TensorFlow. Below is a code snippet of an LSTM in TensorFlow. Could someone explain what control_dependencies does and why it is needed here?
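My current understanding, from the API docs, is that tf.control_dependencies makes every op *created* inside the with block wait for the listed ops to finish first. Here is a minimal toy sketch I wrote to check that idea (my own example, not from the course code; the names counter/increment/doubled are purely illustrative):

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    counter = tf.Variable(0)
    increment = counter.assign_add(1)

    # Ops created inside this block get a control edge from `increment`,
    # so they can only run after the assign has executed.
    with tf.control_dependencies([increment]):
        # tf.identity creates a new op here, so the dependency applies to it
        doubled = tf.identity(counter) * 2

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(doubled))  # runs `increment` first: counter -> 1, prints 2

Is that roughly the right mental model for what is happening in the LSTM code below?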

import tensorflow as tf

# vocab_size, batch_size and num_unrollings are assumed to be defined
# elsewhere in the original script; example values so the snippet stands alone
vocab_size = 27        # e.g. 26 letters plus space
batch_size = 64
num_unrollings = 10

num_nodes = 64

graph = tf.Graph()

with graph.as_default():

    # Parameters

    # Input gate - input, previous output, and bias
    ix = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    im = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    ib = tf.Variable(tf.zeros(shape=(1, num_nodes)))

    # Forget gate - input, previous output, and bias
    fx = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    fm = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    fb = tf.Variable(tf.zeros(shape=(1, num_nodes)))

    # Memory cell - input, state, and bias
    cx = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    cm = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    cb = tf.Variable(tf.zeros(shape=(1, num_nodes)))

    # Output gate - input, previous output, and bias
    ox = tf.Variable(tf.truncated_normal(shape=(vocab_size, num_nodes), mean=-0.1, stddev=0.1))
    om = tf.Variable(tf.truncated_normal(shape=(num_nodes, num_nodes), mean=-0.1, stddev=0.1))
    ob = tf.Variable(tf.zeros(shape=(1, num_nodes)))

    # Variables to save output and state across unrollings
    saved_output = tf.Variable(tf.zeros(shape=(batch_size, num_nodes)), trainable=False)
    saved_state = tf.Variable(tf.zeros(shape=(batch_size, num_nodes)), trainable=False)

    # Classifier weights and biases
    w = tf.Variable(tf.truncated_normal(shape=(num_nodes, vocab_size), mean=-0.1, stddev=0.1))
    b = tf.Variable(tf.zeros(shape=(vocab_size,)))

    # Define a single LSTM cell
    def lstm_cell(i, o, state):
        '''
        One step of the LSTM; see http://deeplearning.net/tutorial/lstm.html
        for the gate definitions.
        '''
        input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
        forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
        update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
        state = forget_gate * state + input_gate * tf.tanh(update)
        output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)

        return output_gate * tf.tanh(state), state

    # Input data
    train_data = list()
    for _ in range(num_unrollings + 1):
        train_data.append(tf.placeholder(tf.float32, shape=(batch_size, vocab_size)))
    train_inputs = train_data[:num_unrollings]
    train_labels = train_data[1:]  # labels are inputs shifted by one time step.

    # Unrolled LSTM loop.
    outputs = list()
    output = saved_output
    state = saved_state
    for i in train_inputs:
        output, state = lstm_cell(i, output, state)
        outputs.append(output)

    # State saving across unrollings
    with tf.control_dependencies([saved_output.assign(output), saved_state.assign(state)]):

        # Classifier
        logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=tf.concat(train_labels, 0), logits=logits))

    # Optimizer with exponential learning-rate decay and gradient clipping
    global_step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(10.0, global_step, 5000, 0.1, staircase=True)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    gradients, v = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
    optimizer = optimizer.apply_gradients(zip(gradients, v), global_step=global_step)
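If my toy example above is right, then the two assign ops in this snippet are never fetched by anything, so the control_dependencies block would be the only thing guaranteeing that saved_output and saved_state actually get updated before the loss is computed on each batch. Is that the whole story, or is there more to it?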

0 Answers:

No answers yet.