光束搜索解码器Tensorflow 2.0

时间:2019-06-04 15:09:23

标签: tensorflow

我正在寻求在Tensorflow 2.0 alpha中实现具有注意力和波束搜索的神经网络排序序列。尽管其网站上的教程非常有用,但是由于contrib库已被弃用,因此我难以确定实现光束搜索的最佳方法-有人能指出正确的方向吗?

我试图使用TF2.0s升级脚本将我的tensorflow 1.X光束搜索升级到2.0,但是它不支持contrib库。

这是光束搜索代码寻找1.x的方式

decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=decoder_cell,
                    embedding=self.embeddings,
                    start_tokens=tf.fill([self.batch_size], tf.constant(2)),
                    end_token=tf.constant(3),
                    initial_state=initial_state,
                    beam_width=self.beam_width,
                    output_layer=self.projection_layer
                )
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
                    decoder, output_time_major=True, maximum_iterations=summary_max_len, scope=decoder_scope)
self.prediction = tf.transpose(outputs.predicted_ids, perm=[1, 2, 0])

1 个答案:

答案 0 :(得分:0)

很少有 Tensorflow 1.x API 迁移到 Tensorflow 2.x 中的不同 API。 Tf.contrib 就是这样一个库,它部分迁移到了 Tensorflow 插件。

因为 tf.contrib.seq2seq.BeamSearchDecoder 被移到了 tfa.seq2seq.BeamSearchDecoder in TFv2.x.

tfa.seq2seq.BeamSearchDecoder(
    cell: tf.keras.layers.Layer,
    beam_width: int,
    embedding_fn: Optional[Callable] = None,
    output_layer: Optional[tf.keras.layers.Layer] = None,
    length_penalty_weight: tfa.types.FloatTensorLike = 0.0,
    coverage_penalty_weight: tfa.types.FloatTensorLike = 0.0,
    reorder_tensor_arrays: bool = True,
    **kwargs
)