我正在尝试构建一个seq2seq模型来生成句子并将其分类。为此,我想将解码器的softmax输出提供给预训练的分类器。但是,此分类器使用Keras嵌入层,因此无法将原始softmax传递到分类器中。我以为我可以使用gumbel softmax来获得一键编码,然后使用我在这里(https://github.com/keras-team/keras/issues/2505)上找到的OneHotEmbedding层来解决这个问题。
Eric Jang为gumbel softmax提供了此TensorFlow代码,我想知道如何将其转换为Keras层。我特别对hard
属性感兴趣,该属性可确保正向传递的向量严格分类,但向后传递的梯度是gumbel softmax输出。我不知道如何在Keras中构建它。有人可以帮忙吗?
谢谢。