MXNet中使用RNN的形状不匹配 - R.

时间:2017-06-17 14:35:27

标签: r deep-learning recurrent-neural-network mxnet

我正在尝试在MXNet中使用RNN进行分类。我的数据大致类似于我创建的矩阵m0和m1。 m0表示例如设备随时间的能量消耗,而m1是我的标签,用于说明设备如何分类(例如,在这种情况下为二进制)。 我的目标是通过查看一段时间内的能耗来检测设备的类别。 我不断收到有关形状不匹配的错误,并且无法通过更改输入参数找到解决方案。您可以在下面看到我的代码和错误消息。 我很感激有关如何处理这个问题的任何建议。

require(mxnet)

m0 <- matrix(runif(200*100), 100, 200)
m1 <- matrix(round(runif(1*200)), 1, 200)

num.round      <- 10
update.period  <- 1
num.rnn.layer  <- 1
seq.len        <- 100
num.hidden     <- 1
num.embed      <- 1
num.label      <- 1
batch.size     <- 1
input.size     <- 1
learning.rate  <- 0.1

X.train <- list(data = m0, label = m1)

model <- mx.rnn(train.data = X.train,
                eval.data = NULL,
                num.rnn.layer = num.rnn.layer,
                seq.len = seq.len,
                num.hidden = num.hidden,
                num.embed = num.embed,
                num.label = num.label,
                batch.size = batch.size,
                input.size = input.size,
                ctx = mx.cpu(),
                num.round = num.round,
                update.period = update.period,
                initializer = mx.init.uniform(0.1),
                learning.rate = learning.rate)
  

[16:07:02] d:\ program files   (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144:   不推荐使用target_shape。

     

[16:07:02] d:\ program files   (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144:   不推荐使用target_shape。

     

[16:07:02] d:\ program files   (86)\詹金斯\工作空间\ mxnet \ mxnet \ SRC \操作\ tensor./matrix_op-inl.h:144:   不推荐使用target_shape。

     

[16:07:02] D:\ Program Files   (86)\詹金斯\工作空间\ mxnet \ mxnet \ DMLC-芯\包括\ DMLC / logging.h:304:

     

[16:07:02] D:\ Program Files   (x86)\ Jenkins \ workspace \ mxnet \ mxnet \ src \ ndarray \ ndarray.cc:299:检查   失败:from.shape()== to-&gt; shape()操作数形状   mismatchfrom.shape =(1,1)to.shape =(1,100)错误   exec $ update.arg.arrays(arg.arrays,match.name,skip.null):

     

[16:07:02] D:\ Program Files   (x86)\ Jenkins \ workspace \ mxnet \ mxnet \ src \ ndarray \ ndarray.cc:299:检查   失败:from.shape()== to-&gt; shape()操作数形状   mismatchfrom.shape =(1,1)to.shape =(1,100)

1 个答案:

答案 0 :(得分:1)

尺寸不匹配的原因是您传递的label维度与序列长度不匹配。 RNN对序列的每个抽头都有一个输出,因此如果您的长度为100,它将有100个输出,每个时间步长一个。您可以通过将m1设置为matrix(round(runif(100*200)), 100, 200)来修复此错误,但您无法使用简化的mx.rnn()接口执行所需操作(即预测整个序列的一个数字)。您需要根据代码here实施自己的网络。为了实现您正在寻找的单输出,您可以丢弃除最后时间步之外的所有输出,并通过Softmax层运行该输出。