我正在尝试将tf的代码转换为pytorch。 我卡住的部分代码是sess.run。据我所知,pytorch不需要它,但是我找不到复制它的方法。我附上了代码。
TF:
ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
for i in range(ebnos_db.shape[0]):
ebno_db = ebnos_db[i]
bers_no_training[i] += sess.run(ber, feed_dict={
batch_size: samples,
noise_var: ebnodb2noisevar(ebno_db, coderate)
})
bers_no_training /= epochs
samples是一个int32,而ebnodb2noisevar()返回一个float32。
TF中的BER计算如下:
ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))
和在PT中:
wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)
我认为BER计算得很好,但主要问题是我不知道如何将sess.run转换为PyTorch,也无法完全理解其功能。
有人可以帮助我吗?
谢谢
答案 0 :(得分:1)
您可以在PyTorch中执行相同的操作,但是在ber
方面更容易:
ber = torch.mean((x != x_hat).float())
就足够了。
是的,PyTorch不需要它,因为它基于动态图构造(与Tensorflow的静态方法不同)。
在tensorflow
中,sess.run
用于将值输入到创建的图形中;此处tf.Placeholder
(代表用户可以“注入”其数据的节点的图形变量)名为batch_size
的{{1}}和samples
的{{1} }。
将其转换为PyTorch通常很简单,因为您不需要在会话中使用任何类似图形的方法。只需将神经网络(或类似物)与正确的输入(例如noise_var
和ebnodb2noisevar(ebno_db, coderate)
一起使用,就可以了。您必须检查图形(以便从samples
和noise_var
构造ber
),然后在PyTorch中重新实现它。
此外,在深入研究框架之前,请先检查PyTorch introductory tutorials,以了解框架。