在numpy中生成随机矩阵,不包含全为1的行

时间:2019-03-05 03:50:45

标签: python arrays numpy

我正在生成一个随机矩阵,

#include <iostream>
#include <string>
#include <vector>

class Player {
public:
    const unsigned    id;
    const std::string color;
    const std::string name;

    void display() const { std::cout << id << ' ' << color << ' ' << name << '\n'; }

    Player() = delete;
    Player(unsigned i, const std::string& c, const std::string& n) :
        id(i),
        color(c),
        name(n)
    {}
};

int main() {
    const char* colors[] = { "red", "black", "blue", "green", "white" };

    std::vector<Player> players;
    for (unsigned i = 0; i < 50; ++i)
        players.emplace_back(Player(
            i + 1,
            colors[i / 10],
            "player" + std::to_string(i + 1)
        ));

    for (const Player& player : players)
        player.display();

    return 0;
}

输出类似

的内容
np.random.randint(2, size=(5, 3))

如何在每行不能包含所有[0,1,0], [1,0,0], [1,1,1], [1,0,1], [0,0,0] 的条件下创建随机矩阵?也就是说,每一行可以是1[1,0,0][0,0,0][1,1,0]]或[1,0,1]或[0,0,1[0,1,0]不能[0,1,1]

感谢您的回答

4 个答案:

答案 0 :(得分:7)

这是一种有趣的方法:

rows = np.random.randint(7, size=(6, 1), dtype=np.uint8)
np.unpackbits(rows, axis=1)[:, -3:]

本质上,您要为每行选择整数0-6,即选择000-110作为二进制。 7将是111(全1)。由于unpackbits的输出是8位数字,您只需要提取二进制数字作为列并获取最后3位数字(您的3列)即可。

输出:

array([[1, 0, 1],
       [1, 0, 0],
       [1, 0, 0],
       [1, 0, 0],
       [0, 1, 1],
       [0, 0, 0]], dtype=uint8)

答案 1 :(得分:4)

如果您始终有3列,则一种方法是显式列出可能的行,然后在其中随机选择,直到您有足够的行为止:

import numpy as np

# every acceptable row
choices = np.array([
    [1,0,0],
    [0,0,0],
    [1,1,0],
    [1,0,1],
    [0,0,1],
    [0,1,0],
    [0,1,1]
])

n_rows = 5
# randomly pick which type of row to use for each row needed
idx = np.random.choice(range(len(choices)), size=n_rows)

# make an array by using the chosen rows
array = choices[idx]

如果这需要归纳为大量的列,则显式列出所有选择将是不切实际的(即使以编程方式创建选择,内存仍然是一个问题;在其中可能的行数呈指数增长列数)。相反,您可以创建一个初始矩阵,然后对所有不可接受的行重新采样,直到没有剩余的行为止。我假设一行仅由1组成,这是不可接受的。但是,很容易将其适应于阈值为1的任意数量的情况。

n_rows = 5
n_cols = 4

array = np.random.randint(2, size=(n_rows, n_cols))
all_1s_idx = array.sum(axis=-1) == n_cols
while all_1s_idx.any():
    array[all_1s_idx] = np.random.randint(2, size=(all_1s_idx.sum(), n_cols))
    all_1s_idx = array.sum(axis=-1) == n_cols

在这里,我们仅对所有不可接受的行进行重新采样,直到没有剩余的行为止。因为所有必需的行都立即被重新采样,所以这应该非常有效。另外,随着列数的增加,行全为1的概率呈指数下降,因此效率不成问题。

答案 2 :(得分:1)

@busybear击败了我,但我还是将其发布,因为它有点笼统:

def not_all(m, k):
    if k>64 or sys.byteorder != 'little':
        raise NotImplementedError
    sample = np.random.randint(0, 2**k-1, (m,), dtype='u8').view('u1').reshape(m, -1)
    sample[:, k//8] <<= -k%8                                                        
    return np.unpackbits(sample).reshape(m, -1)[:, :k]                         

例如:

>>> sample = not_all(1000000, 11)
# sanity checks
>>> unq, cnt = np.unique(sample, axis=0, return_counts=True)
>>> len(unq) == 2**11-1
True
>>> unq.sum(1).max()
10
>>> cnt.min(), cnt.max()
(403, 568)

当我要劫持其他人的答案时,这里是@Nathan接受/拒绝方法的简化版本。

def accrej(m, k):
    sample = np.random.randint(0, 2, (m, k), bool)
    all_ones, = np.where(sample.all(1))
    while all_ones.size:
        resample = np.random.randint(0, 2, (all_ones.size, k), bool)
        sample[all_ones] = resample
        all_ones = all_ones[resample.all(1)]
    return sample.view('u1')

答案 3 :(得分:-3)

使用sum()尝试此解决方案:

import numpy as np

array = np.random.randint(2, size=(5, 3))
for i, entry in enumerate(array):
    if entry.sum() == 3:
        while True:
            new = np.random.randint(2, size=(1, 3))
            if new.sum() == 3:
                continue
            break
        array[i] = new

print(array)

祝你好运,我的朋友!