这是我的代码,为了找到错误,我在 init 函数中放置了一个输入张量以测试Conv1d网络的输出
'''
class NLPBase(NNBase):
def __init__(self, num_inputs, recurrent=False, hidden_size=64):
super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)
if recurrent:
num_inputs = hidden_size
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
constant_(x, 0), np.sqrt(2))
testin = torch.randn(1,1,300)
self.actor = init_(nn.Conv1d(in_channels = 1, out_channels = 60, kernel_size = 3))
testout = self.actor(testin)
print("testout size:",testout.size())
self.critic = nn.Sequential(
init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(),
init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh())
self.critic_linear = init_(nn.Linear(hidden_size, 1))
self.train()
def forward(self, inputs, rnn_hxs, masks):
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
constant_(x, 0), np.sqrt(2))
x = inputs.reshape(1,1,300)
print("x shape:",x.size())
hidden_critic = self.critic(x)
hidden_actor = self.actor(x)
print("hidden_actor size:",hidden_actor.size())
f_actor = init_(nn.Conv1d(in_channels = 1, out_channels = 60, kernel_size = 3))
f_out = f_actor(x)
print("f_out size:", f_out.size())
#test2 = self.critic_linear(hidden_critic)
#print("test2",test2.size())
return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs
'''
但是,打印的信息如下:
测试尺寸:torch.Size([1,60,298])
x形状:torch.Size([1,1,300])
hidden_actor大小:torch.Size([1,60,60,298])
f_out大小:torch.Size([1,60,298])