避免在mxnet中的BucketingModule中的某些层之间共享权重?

时间:2017-12-12 23:22:40

标签: mxnet softmax

我正在使用BucketingModule一起训练多个小型号/机器人。这里,桶密钥是bot_id。但是,每个机器人都有单独的目标标签/类集(因此,每个机器人的softmax层大小不同)。

有没有办法在mxnet中训练这样的模型,我想在所有层中共享所有层的权重,但在所有机器人之间共享一个(softmax)?

如何使用sym_gen方法初始化此类模型? 如果在sym_gen方法中,对于Softmax层我指定num_hidden=size_dict[bot],即

pred = mx.sym.FullyConnected(data=pred, num_hidden=len(size_dict[bot]), name='pred')
pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

我收到错误:

  

推断的形状与shared_exec.arg_array的形状

不匹配

这是有道理的,因为每个机器人都有不同数量的目标类。

1 个答案:

答案 0 :(得分:2)

此问题已在此处发布并解决:https://github.com/apache/incubator-mxnet/issues/9042

您可以使sym_gen(default_bucket_key)返回包含所有这些不同形状的FC层的“主网络”,并且sym_gen(other_keys)返回具有一个特定FC的主网络的子集。请注意,对于主网络,您可能需要使用mx.sym.Group将所有输出组合在一起,因此只返回一个符号。