Tensorflow的API:seq2seq

时间:2017-03-06 06:47:20

标签: python machine-learning tensorflow recurrent-neural-network

我一直在关注https://github.com/kvfrans/twitch/blob/master/main.py教程,使用tensorflow创建和训练基于rnn的聊天机器人。根据我的理解,这些教程是在旧版本的tensorflow上编写的,所以有些部分已经过时,给我一个错误,如:

Traceback (most recent call last):
  File "main.py", line 33, in <module>
    outputs, last_state = tf.nn.seq2seq.rnn_decoder(inputs, initialstate, cell, loop_function=None, scope='rnnlm')
AttributeError: 'module' object has no attribute 'seq2seq'

我修复了其中的一部分,但无法弄清楚tf.nn.seq2seq.rnn_decoder的替代方案是什么以及新模块的参数应该是什么。我目前修正的内容:

tf.nn.rnn_cell.BasicLSTMCell(embedsize)改为 tf.contrib.rnn.BasicLSTMCell(embedsize)

tf.nn.rnn_cell.DropoutWrapper(lstm_cell,keep_prob)已更改为tf.contrib.rnn.DropoutWrapper(lstm_cell,keep_prob)

tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * numlayers)改为 tf.contrib.rnn.MultiRNNCell([lstm_cell] * numlayers)

有人可以帮我弄清楚tf.nn.seq2seq.rnn_decoder会是什么吗?

1 个答案:

答案 0 :(得分:3)

我认为你需要this

extension Collection {

    // EZSE : A parralelized map for collections, operation is non blocking
    public func pmap<R>(_ each: (Self.Iterator.Element) -> R) -> [R?] {
        let indices = indicesArray()
        var res = [R?](repeating: nil, count: indices.count)

        DispatchQueue.concurrentPerform(iterations: indices.count) { (index) in
            let elementIndex = indices[index]
            res[index] = each(self[elementIndex])
        }

        // Above code is non blocking so partial exec on most runs
        return res
    }

    /// EZSE : Helper method to get an array of collection indices
    private func indicesArray() -> [Self.Index] {
        var indicesArray: [Self.Index] = []
        var nextIndex = startIndex
        while nextIndex != endIndex {
            indicesArray.append(nextIndex)
            nextIndex = index(after: nextIndex)
        }
        return indicesArray
    }
}