用于顺序数据的GAN的输入/输出形状

时间:2019-04-24 14:43:10

标签: python time-series mxnet generative-adversarial-network

我正在尝试使用GAN进行时间序列预测。我正在使用MXNet / Gluon。因此,我有一个大小为(N,1)的顺序数据,已将其转换为(N-stepsize,stepsize)。现在,我很难理解网络的输入形状。这里是生成器和鉴别器网络的代码。

netG = nn.Sequential()
with netG.name_scope():
    netG.add(nn.Dense(20))
    netG.add(nn.BatchNorm(momentum = 0.8))
    netG.add(nn.Dropout(0.5))
    netG.add(nn.Dense(15))
    netG.add(nn.BatchNorm(momentum = 0.8))
    netG.add(nn.Dropout(0.5))
    netG.add(nn.Dense(20))
    netG.add(nn.BatchNorm(momentum = 0.8))
    netG.add(nn.Dropout(0.5))
    netG.add(nn.Dense(step_size, activation = "tanh"))


#300, 50, 2
#input shape is inferred
netD = nn.Sequential()
with netD.name_scope():
    netD.add(nn.Dense(20))
    netG.add(nn.BatchNorm(momentum = 0.8))
    netD.add(nn.Dense(15, activation='tanh'))
    netG.add(nn.BatchNorm(momentum = 0.8))
    netD.add(nn.Dense(20, activation='tanh'))
    netD.add(nn.Dense(step_size))

谢谢。

1 个答案:

答案 0 :(得分:0)

您可以使用以下代码检查张量形状: print(mx.viz.print_summary(netG(mx.sym.var('data')), shape={'data':(1,100,10)})) 我在这里假设N-stepsize等于100,stepssize等于10。

鉴别器中有2个错误:您将Batchnorm图层添加到netG而不是netD