我正在尝试使用 stax.GeneralConv()(https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv)在 Google Jax 中实现一维卷积神经网络。 我有一个带有18个输入的1维输入数组和带有6个条目的输出数组。我想实现一个内核宽度为3的CNN,如下所示:
init_random_params, conv_net = stax.serial(
GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
LogSoftmax,
Dense(6),
)
具有初始网络参数:
rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))
但是出现以下错误:
stax.py", line 75, in <listcomp>
next(filter_shape_iter) for c in rhs_spec]
IndexError: tuple index out of range
stax要求维度编号 rhs_spec 至少长2个字符,但我使用的是1维过滤器。有人知道如何解决这个问题吗?
答案 0 :(得分:0)
我自己还没有尝试过,但是我希望一维卷积仍然需要一个方向进行卷积,例如
Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))
换句话说,将W
轴放下,从2d卷积到1d卷积。
与NHC
对应的输入形状为(batch_size, sequence_length, num_channels)
。
请注意,即使通道数可能为1,您仍需要包括该轴,因为GeneralConv
沿num_channels = input_shape['NHC'.index('C')]
的行进行索引查找。