以下是使用 CudnnLSTM 的示例,摘自
https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py :
def _build_rnn_graph_cudnn(self, inputs, config, is_training):
    """Build the inference graph using CUDNN cell.

    Side effects: stores the created cell on `self._cell`, its parameter
    variable on `self._rnn_params`, and the zero initial state on
    `self._initial_state`.

    Args:
        inputs: Rank-3 tensor. Its first two axes are swapped before it is
            fed to the cell (presumably batch-major in, time-major for
            cuDNN -- TODO confirm against the caller).
        config: Object providing `num_layers`, `hidden_size`, `keep_prob`
            and `init_scale`.
        is_training: Truthy when building the training graph. Dropout is
            only non-zero in that case, and the flag is also forwarded to
            the cell call.

    Returns:
        A pair `(outputs, state)`. `outputs` is the cell output with its
        first two axes swapped back and then reshaped to
        `[-1, config.hidden_size]`. `state` is a 1-tuple holding an
        `LSTMStateTuple` built from the `h` and `c` returned by the cell.
    """
    # Swap the first two axes of the input before handing it to the cell.
    inputs = tf.transpose(inputs, [1, 0, 2])
    self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(
        num_layers=config.num_layers,
        num_units=config.hidden_size,
        input_size=config.hidden_size,
        # Parenthesised for readability; dropout is disabled outside training.
        dropout=(1 - config.keep_prob) if is_training else 0)
    # All weights/biases live in one flat variable whose length is reported
    # by the cell itself.
    params_size_t = self._cell.params_size()
    self._rnn_params = tf.get_variable(
        "lstm_params",
        initializer=tf.random_uniform(
            [params_size_t], -config.init_scale, config.init_scale),
        # NOTE(review): presumably needed because the parameter size is not
        # known statically -- confirm.
        validate_shape=False)
    # Zero initial cell (c) and hidden (h) states, one slice per layer.
    c = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
                 tf.float32)
    h = tf.zeros([config.num_layers, self.batch_size, config.hidden_size],
                 tf.float32)
    self._initial_state = (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
    outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
    # Undo the earlier axis swap, then flatten the two leading axes.
    outputs = tf.transpose(outputs, [1, 0, 2])
    outputs = tf.reshape(outputs, [-1, config.hidden_size])
    return outputs, (tf.contrib.rnn.LSTMStateTuple(h=h, c=c),)
这一行:self._cell = tf.contrib.cudnn_rnn.CudnnLSTM(...)
创建了一个 LSTM 层。
然后它调用self._cell对象
outputs, h, c = self._cell(inputs, h, c, self._rnn_params, is_training)
但是我在定义 CudnnLSTM 的源文件中没有找到 __call__() 方法:
https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
既然没有定义 __call__() 方法,为什么可以像调用函数一样调用 self._cell 对象?