class RBM():
    """Bernoulli Restricted Boltzmann Machine trained with contrastive divergence.

    Args:
        nv: number of visible units.
        nh: number of hidden units.
    """
    def __init__(self, nv, nh):
        # W is (nh, nv); the biases are row vectors so they broadcast over a
        # batch (via expand_as) in the sampling methods.
        self.W = torch.randn(nh, nv)
        self.a = torch.randn(1, nh)  # hidden-unit bias
        self.b = torch.randn(1, nv)  # visible-unit bias

    def sample_h(self, x):
        """Return P(h=1|v) and a Bernoulli sample of the hidden units.

        x: (batch, nv) batch of visible states.
        """
        wx = torch.mm(x, self.W.t())  # (batch, nh)
        activation = wx + self.a.expand_as(wx)
        p_h_given_v = torch.sigmoid(activation)
        return p_h_given_v, torch.bernoulli(p_h_given_v)

    def sample_v(self, y):
        """Return P(v=1|h) and a Bernoulli sample of the visible units.

        y: (batch, nh) batch of hidden states.
        """
        wy = torch.mm(y, self.W)  # (batch, nv)
        activation = wy + self.b.expand_as(wy)
        p_v_given_h = torch.sigmoid(activation)
        return p_v_given_h, torch.bernoulli(p_v_given_h)

    def train(self, v0, vk, ph0, phk):
        """One contrastive-divergence (CD-k) parameter update.

        v0/ph0: data visible batch and its hidden probabilities;
        vk/phk: visible batch and hidden probabilities after k Gibbs steps.

        Bug fix: v0.t() @ ph0 has shape (nv, nh) while W is (nh, nv), so the
        gradient must be transposed before it is added to W — otherwise
        PyTorch raises the size-mismatch RuntimeError quoted below.
        """
        self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()
        self.b += torch.sum((v0 - vk), 0)
        self.a += torch.sum((ph0 - phk), 0)
错误:
in train(self, v0, vk, ph0, phk)
19 return p_v_given_h, torch.bernoulli(p_v_given_h)
20 def train(self, v0, vk, ph0, phk):
---> 21 self.W += torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)
22 self.b += torch.sum((v0 - vk), 0)
23 self.a += torch.sum((ph0 - phk), 0)
RuntimeError: The expanded size of the tensor (1682) must match the existing size (100) at non-singleton dimension 1
答案 0 :(得分:2)
print(rbm.W.size())
将向您显示torch.Size([100,1682])
print((torch.mm(v0.t(), ph0)-torch.mm(vk.t(), phk)).size())
将向您显示torch.Size([1682,100])
所以看起来应该像这样:
self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()
代替self.W += torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)
答案 1 :(得分:0)
更改此行:
self.W += torch.mm (v0.t(), ph0) - torch.mm (vk.t(), phk)
对此:
self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()
答案 2 :(得分:0)
尝试一下...
reduce
答案 3 :(得分:0)
每当出现此类尺寸错误时，请尝试在每个变量旁注明其尺寸。如上所述，必须在 train 函数的第一行对结果做转置（transpose）。
class RBM():
    # Binary-unit Restricted Boltzmann Machine.
    # W is (nh, nv); a is the (1, nh) hidden bias, b the (1, nv) visible bias.
    def __init__(self, nv, nh):
        self.W = torch.randn(nh, nv)  # e.g. 100 x 1682
        self.a = torch.randn(1, nh)   # 1 x 100
        self.b = torch.randn(1, nv)   # 1 x 1682

    def sample_h(self, x):
        # P(h=1|v) = sigmoid(x W^T + a); row-vector bias broadcasts over the batch.
        logits = x.mm(self.W.t()) + self.a
        probs = torch.sigmoid(logits)
        return probs, torch.bernoulli(probs)

    def sample_v(self, y):
        # P(v=1|h) = sigmoid(y W + b).
        logits = y.mm(self.W) + self.b
        probs = torch.sigmoid(logits)
        return probs, torch.bernoulli(probs)

    def train(self, v0, vk, ph0, phk):
        # CD-k update. The positive/negative statistics come out (nv, nh),
        # so transpose before accumulating into the (nh, nv) weight matrix.
        positive = v0.t().mm(ph0)
        negative = vk.t().mm(phk)
        self.W += (positive - negative).t()
        self.b += (v0 - vk).sum(dim=0)
        self.a += (ph0 - phk).sum(dim=0)