x [:t:]在python中做了什么? <RNN>

时间:2017-09-08 18:49:21

标签: python numpy tensorflow

我一直在研究RNN的语言模型,并在一本书中找到了Tensorflow的代码。 但我真的不明白x[:t:]在下面的代码中做了什么..... 我是机器学习的初学者,如果有人知道的话,请给我一个线索。

=======代码======

def inference(x, n_batch, maxlen=None, n_hidden=None, n_out=None):
    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.01)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.zeros(shape, dtype=tf.float32)
        return tf.Variable(initial)

    cell = tf.contrib.rnn.BasicRNNCell(n_hidden)
    initial_state = cell.zero_state(n_batch, tf.float32)

    state = initial_state
    outputs = []
    with tf.variable_scope('RNN'):
        for t in range(maxlen):
            if t > 0:
                tf.get_variable_scope().reuse_variables()
            (cell_output, state) = cell(x[:, t, :], state)
            outputs.append(cell_output)

    output = outputs[-1]

    V = weight_variable([n_hidden, n_out])
    c = bias_variable([n_out])
    y = tf.matmul(output, V) + c

    return y

1 个答案:

答案 0 :(得分:2)

看起来x是一个3D矩阵。在这种情况下,[:,t,:]在XZ平面中提取二维矩阵作为立方体的t切片。

>>> import numpy as np
>>> x = np.arange(27).reshape(3,3,3)
>>> x
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])
>>> x[:,1,:]
array([[ 3,  4,  5],
       [12, 13, 14],
       [21, 22, 23]])

:表示轴保持不变。 [:,:,:]将返回整个矩阵,[1,:,:]将沿第一个轴提取第二个切片:

>>> x[1,:,:]
array([[ 9, 10, 11],
       [12, 13, 14],
       [15, 16, 17]])

这里是相应的documentation