theano,索引一个价值相对较小的大矩阵,但已超出边界

时间:2016-12-09 03:06:42

标签: numpy theano

这是我的带有行号的dl4mt(神经机器翻译)的theano代码的一部分。 src_positions是int64的向量,我打印结果,值不超过16。 但是,当我使用src_positions索引注意力watch_mask_时,其形状为(100,100)。它得到了索引超出范围的错误。

这是奇怪的部分:

  1. 首先,attention_mask_和gaussian_mask_具有相同的形状。
  2. 当我使用0.1 * src_positions进行索引时(用注释行4替换第5行)。第8行保持不变,程序运行良好......
  3. 更奇怪的是,当我用注释行7替换第8行,但是保持第5行不变时,程序仍然可以运行!
  4. 我不确定问题是否......这真的很奇怪。希望有人能给我一些建议。

        1] p_t_s = p_t * sntlens  # n_samples * 1, pt in equation
        2] src_positions = tensor.cast(tensor.floor(p_t_s), 'int64') # (n_samples, 1)
        3] src_positions = src_positions.reshape([src_positions.shape[0], ])
        4] # batch_mask = attention_mask_[tensor.cast(src_positions * 0.1, 'int64')]      # n_sample * maxlen
        5] batch_mask = attention_mask_[src_positions]      # n_sample * maxlen
        6] attn_mask = batch_mask[:, :msk_.shape[0]] * msk_.T      # n_sample * n_timestep 
        7] # batch_gauss_mask = gaussian_mask_[tensor.cast(src_positions * 0.1, 'int64')]      # n_sample * maxlen
        8] batch_gauss_mask = gaussian_mask_[src_positions]   # n_sample * maxlen 
        9] gauss_mask = batch_gauss_mask[:, :msk_.shape[0]] * msk_.T      # n_sample * n_timestep 
    

1 个答案:

答案 0 :(得分:1)

似乎问题是基于src_positions发生的。根据您的描述不会有任何问题。也许src_positions会被您发布的部分以外的代码更改