脾气暴躁:在每行中的某些索引之后替换行中的其余列

时间:2019-08-27 04:22:21

标签: python numpy

我在3维ndarray中捕获了一些数据,其维度为:时间刻度,样本数和每个样本的值数。

通过先前的操作,我知道在某个时间刻度上,每个样本编号是否都变为无效。如果从未发生这种情况,则将数字设置为-1。在其他情况下,它指示样本变为无效的时间间隔。

我想做的是要么将其余的列都清空,要么将无效数据(包括和位于右边)的列设置为nans,或者使用某种掩蔽或索引技术来生成仅将数据保留在左侧。

我已经阅读或发现了与类似问题有关的参考,这些问题涉及花式索引,slice(),布尔数组和掩码数组,但是我没有找到实现目标的方法。

import numpy as np

# dimensions are timestep, sample, and values per sample.  To make it easy, let's 
# do 3 time steps, 4 samples, and 2 values per sample.
data = np.array( [
    [ # Timestep 0
        [ 1, 2 ],  # Sample 1
        [ 3, 4 ],  #        2
        [ 5, 6 ],  #        3
        [ 7, 8 ],  #        4
    ],
    [ # Timestep 1
        [ 1, 2 ], 
        [ 3, 4 ],
        [ 5, 6 ],
        [ 7, 8 ],
    ],
    [ # Timestep 2
        [ 1, 2 ], 
        [ 3, 4 ],
        [ 5, 6 ],
        [ 7, 8 ],
    ],
])

每个样本可能在一个时间步上变得无效。如果时间步永不无效,则值为-1。

invalid_at = [ 
        0, # becomes invalid at timestep 0
        2, #                             2
        1, #                             1 
        -1 # Never is invalid
        ]

例如,如果我们用nan替换无效值,那么结果数组应该像

data = np.array( [
    [ # Timestep 0
        [ n, n ],  # Sample 1
        [ 3, 4 ],  #        2
        [ 5, 6 ],  #        3
        [ 7, 8 ],  #        4
    ],
    [ # Timestep 1
        [ n, n ], 
        [ 3, 4 ],
        [ n, n ],
        [ 7, 8 ],
    ],
    [ # Timestep 2
        [ n, n ], 
        [ n, n ],
        [ n, n ],
        [ 7, 8 ],
    ],
])

我遇到的主要困难是我有开始索引,但是我找不到一种方法来创建切片(使用花式索引或其他方法)以供我分配。

例如,以下操作无效:

data[ :, invalid_at:-1, : ] = np.nan

我希望会发生的是,对at数组中的无效对象进行评估并生成每行切片。

我可以使用for循环来做到这一点,但是我宁愿将其向量化以提高速度和更高的可伸缩性。有什么想法吗?

1 个答案:

答案 0 :(得分:1)

有两种方法可以做到这一点。主要问题是您尝试应用的索引参差不齐。

如果样本数量很少,则可以用相对较少的额外开销循环遍历它们。此选项非常简单,并支持简单的切片索引,这通常是最快的索引类型,因为它不需要额外的数据或掩码副本:

for sample, step in enumerate(invalid_at):
    if step < 0:
        continue
    data[step:, sample, :] = np.nan

如果您真的需要一步执行此操作,则可以构造一个蒙版并应用它。数组具有维度(时间步长,样本,x)。遮罩仅需要前两个尺寸。您需要设置一个条件,例如“如果某个元素在时间步t上并且sample大于或等于invalid_at[t],则将该元素设置为True。该条件可以应用于一对广播数组:一个用于时间步,一个用于样本:

trange = np.arange(data.shape[0]).reshape(-1, 1)
srange = np.array(invalid_at).reshape(1, -1)
srange[srange == -1] = data.shape[0]
mask = (trange >= srange)
data[mask, :] = np.nan

这仅在您为dtype=np.float显式设置data或类似设置的情况下有效,因为当前定义的整数to不支持NaN。