tf.nn.rnn_cell.GRUCell是在CPU设备上构建的

时间:2018-07-24 00:03:29

标签: tensorflow lstm rnn seq2seq

我现在正在训练2层seq2seq模型,并使用了gru_cell。

def create_rnn_cell():
    encoDecoCell = tf.contrib.rnn.GRUCell(emb_dim)
    encoDecoCell = tf.contrib.rnn.DropoutWrapper(
                                                 encoDecoCell,
                                                 input_keep_prob=1.0,
                                                 output_keep_prob=0.7
                                                 )
    return encoDecoCell

encoder_mutil = tf.contrib.rnn.MultiRNNCell(
            [create_rnn_cell() for _ in range(num_layers)],
        )

query_encoder_emb = tf.contrib.rnn.EmbeddingWrapper(
                                        encoder_mutil, 
                                        embedding_classes=vocab_size,                                                              
                                        embedding_size=word_embedding
                                        )

时间轴对象用于获取图中每个节点的执行时间,我发现GRU_cell(包括MatMul)内部的大多数操作都发生在CPU设备上,这使其运行非常缓慢。我安装了tf-1.8的gpu版本。对此有何评论?我在这里想念什么吗? 我猜tf.variable_scope出问题了,因为我对训练数据使用了不同的存储桶。这就是我在不同的bucktes之间重用变量的方式:

for i, bucket in enumerate(buckets):
    with tf.variable_scope(name_or_scope="RNN_encoder", reuse=True if i > 0 else None) as var_scope:
        query_output, query_state = tf.contrib.rnn.static_rnn(query_encoder_emb,inputs=self.query[:bucket[0]],dtype=tf.float32)

execution time screenshot

1 个答案:

答案 0 :(得分:0)

我发现了问题。在EmbeddingWrapper的源代码中,使用了CPU。 tf.contrib.rnn.EmbeddingWrapper 我重写了此功能,现在它可以在GPU上运行并且速度更快。因此,如果要使用tf.contrib.rnn.EmbeddingWrapper,请务必小心。