如何将tensorflow中的以下嵌入代码传输到pytorch?

时间:2019-03-13 03:31:11

标签: tensorflow pytorch

我在Tensorflow中有一个嵌入代码,如下所示:

self.input_u = tf.placeholder(tf.int32, [None, user_length], name="input_u")
with tf.name_scope("user_embedding"):
        self.W1 = tf.Variable(
            tf.random_uniform([user_vocab_size, embedding_size], -1.0, 1.0),
            name="W")
        self.embedded_user = tf.nn.embedding_lookup(self.W1, self.input_u)
        self.embedded_users = tf.expand_dims(self.embedded_user, -1)

我想用pytorch重写,该怎么做?

0 个答案:

没有答案