我试图在CNTK中实现一个关注Seq2Seq模型,这与CNTK Tutorial 204非常相似。但是,一些小的差异会导致各种问题和错误消息,我不明白。这里有很多问题,可能是相互关联的,都源于我不理解的一些事情。
注意(如果它很重要)。我的输入数据来自MinibatchSourceFromData
,是从适合RAM的NumPy数组创建的,我不会将它存储在CTF中。
ins = C.sequence.input_variable(input_dim, name="in", sequence_axis=inAxis)
y = C.sequence.input_variable(label_dim, name="y", sequence_axis=outAxis)
因此,形状为[#, *](input_dim)
和[#, *](label_dim)
。
问题1:当我运行CNTK 204 Tutorial并使用.dot
将其图表转储到cntk.logging.plot
文件时,我看到其输入形状为{{ 1}}。这怎么可能?
[#](-2,)
)在哪里消失? 问题2:在同一个教程中,我们有*
。我不明白这一点。在我的模型中有2个动态轴和1个静态,所以"第三个到最后"轴将是attention_axis = -3
,即批轴。但绝对不应该在批量轴上计算注意力
我希望查看教程代码中的实际轴可以帮助我理解这一点,但上面的#
问题让这更令人困惑。
将[#](-2,)
设置为attention_axis
会出现以下错误:
-2
在创建培训时间模型期间:
RuntimeError: Times: The left operand 'Placeholder('stab_result', [#, outAxis], [128])'
rank (1) must be >= #axes (2) being reduced over.
其中def train_model(m):
@C.Function
def model(ins: InputSequence[Tensor[input_dim]],
labels: OutputSequence[Tensor[label_dim]]):
past_labels = Delay(initial_state=C.Constant(seq_start_encoding))(labels)
return m(ins, past_labels) #<<<<<<<<<<<<<< HERE
return model
是解码器中最后stab_result
层之前的Stabilizer
。我可以在点文件中看到,在Dense
实现的中间出现了大小为1的虚假尾随维度。
将AttentionModel
设置为attention_axis
会出现以下错误:
-1
其中64是我的RuntimeError: Binary elementwise operation ElementTimes: Left operand 'Output('Block346442_Output_0', [#, outAxis], [64])'
shape '[64]' is not compatible with right operand
'Output('attention_weights', [#, outAxis], [200])' shape '[200]'.
,200是我的attention_dim
。据我了解,注意模型中的元素attention_span
绝对不应该将这两者合并在一起,因此*
绝对不是正确的轴。
问题3:我的理解是否正确?什么是正确的轴,为什么它导致上述两个例外中的一个?
感谢您的解释!
答案 0 :(得分:3)
首先,一些好消息:最新版本的AttentionModel中已经修复了一些问题(几天后CNTK 2.2通常会提供):
attention_span
或attention_axis
。如果您没有指定它们并将它们保留为默认值,则会在整个序列中计算注意力。事实上,这些论点已被弃用。关于你的问题:
维度不是负面的。我们在不同的地方使用某些负数表示某些事物:-1是基于第一个小批量推断一次的维度,-2是我认为占位符的形状,-3是将被推断的维度每个小批量(例如当您将可变大小的图像提供给卷积时)。我想如果你在第一个小批量之后打印图形,你应该看到所有的形状都是具体的。
attention_axis
是应该隐藏的实现细节。基本上attention_axis=-3
将创建(1,1,200)的形状,attention_axis=-4
将创建(1,1,1,200)的形状,依此类推。一般来说,任何超过-3的东西都不能保证工作,任何小于-3的东西只会增加1个而没有任何明显的好处。当然,好消息是你可以在最新的大师中忽略这个论点。
TL; DR:如果您在掌握(或在几天内从CNTK 2.2开始),请替换AttentionModel(attention_dim, attention_span=200, attention_axis=-3)
AttentionModel(attention_dim)
。它更快,不包含令人困惑的参数。从CNTK 2.2开始,不推荐使用原始API。