如何在Tensorflow中有效使用tf.bucket_by_sequence_length?

时间:2017-06-15 09:19:07

标签: python tensorflow deep-learning bucket

所以我试图使用来自Tensorflow的tf.bucket_by_sequence_length(),但却无法弄清楚如何使它工作。

基本上,它应该将序列(不同长度)作为输入并将序列桶作为输出,但它似乎不会这样工作。

从这个讨论: https://github.com/tensorflow/tensorflow/issues/5609 我的印象是它需要一个队列才能按顺序提供此功能。但现在还不清楚。

功能文档可在此处找到:https://www.tensorflow.org/versions/r0.12/api_docs/python/contrib.training/bucketing#bucket_by_sequence_length

2 个答案:

答案 0 :(得分:1)

实际上,您需要将输入张量作为队列,例如可以是tf.FIFOQueue().deque()tf.TensorArray().read(tf.train.range_input_producer())

这款笔记本解释得非常好:

https://github.com/wcarvalho/jupyter_notebooks/blob/ebe762436e2eea1dff34bbd034898b64e4465fe4/tf.bucket_by_sequence_length/bucketing%20practice.ipynb

答案 1 :(得分:0)

我的以下回答基于Tensorflow2.0。我可以看到您可能正在使用旧版本的Tensorflow。但是,如果碰巧使用了新版本,则可以通过以下方式有效地使用bucket_by_sequence_length API。

precedencegroup Chaining {
    associativity: left
}

infix operator » : Chaining

extension Result {
    static func »<T>(value: Self, key: KeyPath<Success, T>) -> T? {
        switch value {
        case .success(let win):
            return win[keyPath: key]
        case .failure(let fail):
            print(fail.localizedDescription)
            return nil
        }
    }
}

// I included a custom type so that it could be customizable if needed.  
enum Err<Wrapped> {
    static func »<T>(value: Self, key: KeyPath<Wrapped, T>) -> T? {
        switch value {
        case .error(let err):
            print(err.localizedDescription)
            return nil
        case .some(let wrapper):
            return wrapper[keyPath: key]
        }
    }
    case error(Error)
    case some(Wrapped)
}

func errorWrapped() -> Err<String> {
    .some("Hello World")
}

func pleaseWork()  {
    print(errorWrapped()»\.isEmpty)
}