为什么numpy vs cntk形状不一致?

时间:2017-01-26 03:51:25

标签: cntk

我刚开始学习cntk。但是,我有一个基本问题阻碍了我的进步。我通过了以下测试:

import numpy as np
from cntk import input_variable, plus

def test_simple(self):

    x_input = np.asarray([[1, 2, 2]], dtype=np.int64)
    assert (1, 3) == x_input.shape

    y_input = np.asarray([[5, 3, 3]], dtype=np.int64)
    assert (1, 3) == y_input.shape

    x = input_variable(x_input.shape[1])
    assert (3, ) == x.shape

    y = input_variable(y_input.shape[1])
    assert (3, ) == y.shape

    x_plus_y = plus(x, y)
    assert (3, ) == x_plus_y.shape

    res = x_plus_y.eval({x: x_input, y: y_input})

    assert 6 == res[0, 0, 0]
    assert 5 == res[0, 0, 1]
    assert 5 == res[0, 0, 2]

据我所知,输出的形状为(1,1,3),因为第一轴和第二轴分别是批量和默认动态轴。

但是,为什么我需要将输入变量的形状设置为(3,)而不是(1,3)。使用(1,3)失败。

为什么图中输入节点的形状与用作该节点输入的numpy数据之间存在不一致?

谢谢你, 稻谷

1 个答案:

答案 0 :(得分:2)

Function.forward的“论据”描述中对此进行了一些解释。另一种描述是here。你混淆的原因可能是CNTK做了一些“有用的”转换。

如果您将输入指定为(1,3),则需要提供(1,3)数组的列表,以防小批量没有序列轴或(x,1,3)数组的列表具有序列轴的小批量的情况(其中x对于小批量中的每个序列可能是不同的)。类似地,如果将输入指定为(3,),则需要提供(3,)向量列表或(x,3)向量列表。

混淆可能源于未提供清单的情况。在那种情况下,CNTK在所提供的张量的引导轴上进行迭代,并从这些元素中创建列表,例如a(5,1,3)张量成为一批5个元素,每个元素的形状为(1,3)。