如何在chainer中实现批量线性链接,为批处理中的每个示例支持不同的权重?

时间:2017-12-01 08:38:19

标签: python chainer

我们使用chainer.functions.linear来计算y=Wx+b

在我的情况下,我必须实现一个更多维度的线性链接。

假设输入示例为(c, x),则所需的输出为y = W_c x + b。让我们忽略偏见并将其y = W_c x {c}的基数是事先已知的(通常是样本类别)。

理论上,W参数可以实现为3-d张量(C, y_dims, x_dims)。但还有什么?我是否必须遍历批处理并提取W_c形状(y_dims, x_dims)并仅针对该functions.linear形状的示例调用(1, x_dims)

1 个答案:

答案 0 :(得分:0)

好吧,我自己找到了一个问题的解决方案。

让数据具有如下形状,

  • W: (C, y_dims, x_dims)
  • x: (batch, x_dims)
  • c: (batch, 1)

首先,我必须为批次中的每个x获得一个权重矩阵:

W_c = chainer.functions.get_item(W, chainer.as_variable(c).data)
y = chainer.functions.batch_matmul(W_c, chainer.expand_dims(x, 2)) // in shape (batch, y_dims, 1)

因此,此处的关键功能是get_item,它同时接受numpy.ndarraycupy.ndarray,但 chainer.Variable。它的工作方式与numpy.take类似,但可以区分并节省大量工作。