外层产品如何与chainer一起做?

时间:2017-11-22 07:15:34

标签: chainer

如何将(外部特征向量及其自身的)外部产品包含在chainer中,特别是以与批处理兼容的方式?

2 个答案:

答案 0 :(得分:1)

F.matmul也非常方便。

根据输入的形状,您可以将其与F.expand_dims(当然F.reshape也适用)结合使用,或使用transa / transb个参数。

有关详细信息,请参阅the official documentation of functions

代码

import chainer.functions as F
import numpy as np

print("---")
x = np.array([[[1], [2], [3]], [[4], [5], [6]]], 'f')
y = np.array([[[1, 2, 3]], [[4, 5, 6]]], 'f')
print(x.shape)
print(y.shape)
z = F.matmul(x, y)
print(z)


print("---")
x = np.array([[[1], [2], [3]], [[4], [5], [6]]], 'f')
y = np.array([[[1], [2], [3]], [[4], [5], [6]]], 'f')
print(x.shape)
print(y.shape)
z = F.matmul(x, y, transb=True)
print(z)


print("---")
x = np.array([[1, 2, 3], [4, 5, 6]], 'f')
y = np.array([[1, 2, 3], [4, 5, 6]], 'f')
print(x.shape)
print(y.shape)
z = F.matmul(
    F.expand_dims(x, -1),
    F.expand_dims(y, -1),
    transb=True)
print(z)

输出

---
(2, 3, 1)
(2, 1, 3)
variable([[[  1.   2.   3.]
           [  2.   4.   6.]
           [  3.   6.   9.]]

          [[ 16.  20.  24.]
           [ 20.  25.  30.]
           [ 24.  30.  36.]]])
---
(2, 3, 1)
(2, 3, 1)
variable([[[  1.   2.   3.]
           [  2.   4.   6.]
           [  3.   6.   9.]]

          [[ 16.  20.  24.]
           [ 20.  25.  30.]
           [ 24.  30.  36.]]])
---
(2, 3)
(2, 3)
variable([[[  1.   2.   3.]
           [  2.   4.   6.]
           [  3.   6.   9.]]

          [[ 16.  20.  24.]
           [ 20.  25.  30.]
           [ 24.  30.  36.]]])

答案 1 :(得分:0)

您可以使用F.reshapeF.broadcast_to显式处理数组。

假设您有2-dim数组h的形状(minibatch,feature)。 如果您要计算hh的外部产品,请尝试以下代码。 这是你想要做的吗?

import numpy as np
from chainer import functions as F


def outer_product(h):
    s0, s1 = h.shape
    h1 = F.reshape(h, (s0, s1, 1))
    h1 = F.broadcast_to(h1, (s0, s1, s1))
    h2 = F.reshape(h, (s0, 1, s1))
    h2 = F.broadcast_to(h2, (s0, s1, s1))
    h_outer = h1 * h2
    return h_outer

# test code
h = np.arange(12).reshape(3, 4).astype(np.float32)
h_outer = outer_product(h)
print(h.shape)
print(h_outer.shape, h_outer.data)