将sess.run转换为pytorch

时间:2020-04-27 09:44:47

标签: python numpy tensorflow pytorch

我正在尝试将tf的代码转换为pytorch。 我卡住的部分代码是sess.run。据我所知,pytorch不需要它,但是我找不到复制它的方法。我附上了代码。

TF:

ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples是一个int32,而ebnodb2noisevar()返回一个float32。

TF中的

BER计算如下:

ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

和在PT中:

wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

我认为BER计算得很好,但主要问题是我不知道如何将sess.run转换为PyTorch,也无法完全理解其功能。

有人可以帮助我吗?

谢谢

1 个答案:

答案 0 :(得分:1)

您可以在PyTorch中执行相同的操作,但是在ber方面更容易:

ber = torch.mean((x != x_hat).float())

就足够了。

是的,PyTorch不需要它,因为它基于动态图构造(与Tensorflow的静态方法不同)。

tensorflow中,sess.run用于将值输入到创建的图形中;此处tf.Placeholder(代表用户可以“注入”其数据的节点的图形变量)名为batch_size的{​​{1}}和samples的{​​{1} }。

将其转换为PyTorch通常很简单,因为您不需要在会话中使用任何类似图形的方法。只需将神经网络(或类似物)与正确的输入(例如noise_varebnodb2noisevar(ebno_db, coderate)一起使用,就可以了。您必须检查图形(以便从samplesnoise_var构造ber),然后在PyTorch中重新实现它。

此外,在深入研究框架之前,请先检查PyTorch introductory tutorials,以了解框架。