Google JAX 1D卷积神经网络

时间:2020-06-13 08:17:55

标签: python cnn jax

我正在尝试使用 stax.GeneralConv()https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#GeneralConv)在 Google Jax 中实现一维卷积神经网络。 我有一个带有18个输入的1维输入数组和带有6个条目的输出数组。我想实现一个内核宽度为3的CNN,如下所示:

init_random_params, conv_net = stax.serial(
    GeneralConv(('NC','IO','NC'),1,(3,),padding='SAME'), # dimension_numbers = ('NC','IO','NC')
    LogSoftmax,
    Dense(6),
)

具有初始网络参数:

rng = jax.random.PRNGKey(0)
_, init_params = init_random_params(rng, (18,))

但是出现以下错误:

stax.py", line 75, in <listcomp>
    next(filter_shape_iter) for c in rhs_spec]

IndexError: tuple index out of range

stax要求维度编号 rhs_spec 至少长2个字符,但我使用的是1维过滤器。有人知道如何解决这个问题吗?

1 个答案:

答案 0 :(得分:0)

我自己还没有尝试过,但是我希望一维卷积仍然需要一个方向进行卷积,例如

Conv2d = functools.partial(GeneralConv, ('NHWC', 'HWIO', 'NHWC'))
Conv1d = functools.partial(GeneralConv, ('NHC', 'HIO', 'NHC'))

换句话说,将W轴放下,从2d卷积到1d卷积。

NHC对应的输入形状为(batch_size, sequence_length, num_channels)

请注意,即使通道数可能为1,您仍需要包括该轴,因为GeneralConv沿num_channels = input_shape['NHC'.index('C')]的行进行索引查找。