Where is __call__() in tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py?

Date: 2018-06-11 19:56:28

Tags: python tensorflow lstm rnn

Here is an example that uses CudnnLSTM:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

def _build_rnn_graph_cudnn(self, inputs, config, is_training):
  """Build the inference graph using CUDNN cell."""
  # Convert batch-major [batch, time, depth] input to time-major [time, batch, depth].
  inputs = tf.transpose(inputs, [1, 0, 2])
  self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
      num_layers=config.num_layers,
      num_units=config.hidden_size,
      input_size=config.hidden_size,
      dropout=1 - config.keep_prob if is_training else 0)
  # One flat, opaque parameter buffer holds the weights of all layers.
  params_size_t = self._cell.params_size()
  self._rnn_params = tf.get_variable(
      "lstm_params",
      initializer=tf.random_uniform(
          [params_size_t], -config.init_scale, config.init_scale),
      validate_shape=False)
  # Zero initial cell and hidden states, shape [num_layers, batch, units].
  c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
               tf.float32)
  h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
               tf.float32)
  self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
  # Run the LSTM; returns the outputs plus the final h and c states.
  outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
  # Back to batch-major, then flatten to [batch * time, hidden_size].
  outputs = tf.transpose(outputs, [1, 0, 2])
  outputs = tf.reshape(outputs, [-1, config.hidden_size])
  return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)

This line, self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(...), creates an LSTM layer.

It then calls the self._cell object:

outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
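Calling an object like this is roughly type(obj).__call__(obj, ...), and Python finds __call__ on the object's class by walking its method resolution order (MRO), so the method can live in any base class. A minimal sketch of locating the class that actually defines a method follows; find_defining_class, BaseLayer and MyLSTMLayer are hypothetical names used only for illustration, not TensorFlow APIs.

```python
def find_defining_class(obj, attr_name):
    """Return the first class in type(obj).__mro__ whose own namespace defines attr_name."""
    for klass in type(obj).__mro__:
        if attr_name in vars(klass):
            return klass
    return None  # Not defined anywhere in the hierarchy.


# Hypothetical stand-ins: the base class defines __call__(),
# the subclass only defines call().
class BaseLayer:
    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)


class MyLSTMLayer(BaseLayer):
    def call(self, inputs):
        return inputs


layer = MyLSTMLayer()
print(find_defining_class(layer, "__call__").__name__)  # BaseLayer
print(find_defining_class(layer, "call").__name__)      # MyLSTMLayer
```

The same helper could be pointed at the real self._cell object to see which base class supplies __call__ in a given TensorFlow version; which class that is has not been checked here.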

But I could not find a __call__() method in the file where CudnnLSTM is defined: https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py

How can the self._cell object be called when no __call__() method is defined?
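One common way an object becomes callable without its own class defining __call__() is inheritance from a base class that does. The sketch below assumes a Keras-style pattern, in which the base layer's __call__ builds the layer on first use and then delegates to a call() method that subclasses override; Layer and ToyRNN are hypothetical names, not the actual TensorFlow classes.

```python
class Layer:
    """Hypothetical base class: owns __call__ and delegates to call()."""

    def __init__(self):
        self.built = False

    def build(self, inputs):
        # Subclasses override this to create their state/variables.
        pass

    def __call__(self, inputs, *args, **kwargs):
        # Build lazily on the first call, then hand off to the subclass.
        if not self.built:
            self.build(inputs)
            self.built = True
        return self.call(inputs, *args, **kwargs)

    def call(self, inputs, *args, **kwargs):
        raise NotImplementedError


class ToyRNN(Layer):
    """Hypothetical subclass: defines build() and call(), but no __call__()."""

    def build(self, inputs):
        self.scale = 2

    def call(self, inputs, training=True):
        # Toy computation standing in for the real recurrent step.
        return [x * self.scale for x in inputs] if training else list(inputs)


cell = ToyRNN()
print(callable(cell))                   # True
print(cell([1, 2, 3]))                  # [2, 4, 6]
print(cell([1, 2, 3], training=False))  # [1, 2, 3]
print("__call__" in vars(ToyRNN))       # False
```

Under this pattern, reading only the subclass's file shows build() and call() but never __call__(), because that method is inherited from the base class.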

0 answers