以下代码中存在pytorch不匹配错误:
# Visualization
import matplotlib.pyplot as plt
digit_size = 28
z_sample = np.random.rand(1, latent_dim) # random
z_sample.dtype = 'float32'
z_sample = torch.tensor(z_sample)
plt.figure(figsize=(20, 1))
for i in range(10):
c = np.zeros((1, cond_dim), dtype='float32')
c[i] = 1
c = torch.tensor(c)
cvae.eval()
with torch.no_grad():
img = cvae.decoder(z_sample, c)
# reshape (if needed)
plt.subplot(1, i+1)
plt.axis('off')
plt.imshow(img, cmap='Greys_r',)
plt.show()
##### test #####
print(np.size(z_sample))
print(latent_dim)
我定义的latent_dim
等于2。
我知道错误来自哪里。
来自以下几行:
z_sample = np.random.rand(1, latent_dim) # random
z_sample.dtype = 'float32'
z_sample = torch.tensor(z_sample)
在第一行中,我定义了一个1 * 2随机向量 但是在将其从“数组”更改为“张量”后,它给了我1 * 4向量,因此出现了不匹配错误。
谁知道我该如何解决? 它也应该仍然是“ float32”。
编辑: CVAE是:
class CVAE(nn.Module):
def __init__(self, x_dim, h_dim1, h_dim2, z_dim, c_dim):
super(CVAE, self).__init__()
# encoder part
self.fc1 = nn.Linear(x_dim + c_dim, h_dim1)
self.fc2 = nn.Linear(h_dim1, h_dim2)
self.fc31 = nn.Linear(h_dim2, z_dim)
self.fc32 = nn.Linear(h_dim2, z_dim)
# decoder part
self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
self.fc5 = nn.Linear(h_dim2, h_dim1)
self.fc6 = nn.Linear(h_dim1, x_dim)
def encoder(self, x, c):
concat_input = torch.cat((x, c), 1)
h = F.relu(self.fc1(concat_input))
h = F.relu(self.fc2(h))
return self.fc31(h), self.fc32(h)
def sampling(self, mu, log_var):
std = torch.exp(0.5*log_var)
eps = torch.randn_like(std)
return eps.mul(std).add(mu) # return z sample
def decoder(self, z, c):
concat_input = torch.cat((z, c), 1)
h = F.relu(self.fc4(concat_input))
h = F.relu(self.fc5(h))
return F.sigmoid(self.fc6(h))
def forward(self, x, c):
mu, log_var = self.encoder(x.view(-1, 784), c)
z = self.sampling(mu, log_var)
return self.decoder(z, c), mu, log_var
模型是:
# Create Model (change None)
cond_dim = train_loader.dataset.train_labels.unique().size(0)
latent_dim = 2
cvae = CVAE(x_dim=28*28, h_dim1=512, h_dim2=256, z_dim=latent_dim, c_dim=cond_dim)
# Device setting
cvae = cvae.to(device)
,错误是:
RuntimeError:大小不匹配,m1:[1 x 14],m2:[12 x 256],位于/pytorch/aten/src/TH/generic/THTensorMath.cpp:197