我正在尝试使用PyMC3实现Hierarchical Dirichlet Process(HDP)主题模型。HDP的图模型如下所示:
我想出了以下代码:
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pymc3 as pm
from theano import tensor as tt
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(0)
def stick_breaking(beta):
    """Convert stick-breaking fractions into stick weights.

    Each entry of ``beta`` is multiplied by the product of ``1 - beta``
    over all earlier entries (that product is taken to be 1 for the first
    entry).

    Args:
        beta: 1-D Theano tensor of break fractions.

    Returns:
        A tensor with the same length as ``beta`` holding the weights.
    """
    # Running product of the fraction left over after each break.
    leftover = tt.extra_ops.cumprod(1 - beta)
    # Shift by one position: the first break sees the whole stick (1).
    remaining_before = tt.concatenate([[1], leftover[:-1]])
    return beta * remaining_before
def main():
    """Build a truncated HDP topic model on a toy corpus, sample it, and plot.

    The model is assembled inside a ``pm.Model`` context, then
    ``pm.sample`` draws 2000 samples and ``pm.traceplot`` /
    ``plt.show`` display the resulting trace.

    Takes no arguments and returns ``None``.
    """
    # load data: one row per document
    data = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]])
    # number of entries in each row of ``data``
    Wd = [len(doc) for doc in data]

    # HDP parameters
    T = 10  # top-level truncation
    K = 2  # group-level truncation
    V = 4  # number of words
    D = 3  # number of documents

    with pm.Model() as model:
        # top-level stick breaking
        gamma = pm.Gamma('gamma', 1., 1.)
        beta_prime = pm.Beta('beta_prime', 1., gamma, shape=T)
        beta = pm.Deterministic('beta', stick_breaking(beta_prime))

        # group-level stick breaking
        alpha = pm.Gamma('alpha', 1., 1.)
        pi_prime = pm.Beta("pi_prime", 1, alpha, shape=K)  # Sethuraman's stick breaking
        # pi_prime = [pm.Beta("pi_prime_%s_%s" %(j,k), alpha*(beta[k]), alpha*(1-np.sum(beta[:k+1])), shape=1)
        #             for j in range(K) for k in range(T)]  # Teh's stick breaking
        pi = pm.Deterministic('pi', stick_breaking(pi_prime))

        # top-level DP
        H = pm.Dirichlet("H", a=np.ones(V), shape=V)
        # With a 2-D ``shape``, the PyMC3 releases that raised AssertionError
        # on this line check that ``n`` has one entry per row; a NumPy scalar
        # (shape ``()``) fails that check. Supplying the same total count for
        # each of the T rows gives ``n`` the expected shape ``(T,)``.
        phi_top = pm.Multinomial(
            'phi_top', n=np.full(T, np.sum(Wd)), p=H, shape=(T, V)
        )
        # NOTE(review): ``comp_dists`` here receives a random variable, not a
        # distribution object -- confirm against the pm.Mixture documentation.
        G0 = pm.Mixture('G0', w=beta, comp_dists=phi_top)

        # group-level DP
        phi_group = [
            pm.Multinomial('phi_group_%s' % j, n=Wd[j], p=G0)
            for j in range(D)
        ]
        # NOTE(review): same concern as for G0 -- ``comp_dists`` is given
        # random variables; verify this is what pm.Mixture expects.
        Gj = [
            pm.Mixture('G_%s' % j, w=pi, comp_dists=phi_group[j])
            for j in range(D)
        ]

        # likelihood: one observed Categorical per (document, position) pair
        w = [
            pm.Categorical("w_%s_%s" % (j, i), p=Gj[j], observed=data[j][i])
            for j in range(D)
            for i in range(Wd[j])
        ]

    with model:
        trace = pm.sample(2000, n_init=1000, random_seed=42)
        pm.traceplot(trace)
        plt.show()
# Run the model only when this file is executed as a script.
if __name__ == "__main__":
    main()
但是,我目前遇到了一个AssertionError,
它使我无法调试模型的其余部分。该错误发生在以下这一行:
phi_top = pm.Multinomial('phi_top', n=np.sum(Wd), p=H, shape=(T,V))
没有关于错误的其他信息。有谁知道如何解决这个问题?