如何有效地找出序数中的数字是多少?

时间:2017-02-25 13:55:58

标签: python numpy vectorization

以下是一个例子:

  1. 4位数
  2. 第一个,第二个数字的范围是:0~5(总共六个数字)
  3. 第三,第四位数的范围是:0~4(总共五位数)
  4. 所以,0000,0040,0111,4455还可以,但5555,4555,4466还不行。

    我想要的是找到顺序中的2345是什么? (从零开始索引)`

    例如,序数中0001为“1”。同样,0010是“5”。

    可以通过计算,

      

    (5 * 6 * 6 * 1)* 2 +(6 * 6 * 1)* 3 +(6 * 1)* 4 +(1)* 5 = 497

    我在Python

    中创建了一个函数
    import numpy as np
    
    def find_real_index_of_state(state, num_cnt_in_each_digit):
        """
        parameter
        =========
        state(str) 
        num_cnt_in_each_digit(list) : the number of number in each digit
        """
        num_of_digit = len(state)
        digit_list = [int(i) for i in state]   
    
        num_cnt_in_each_digit.append(1)
    
        real_index = 0
        for i in range(num_of_digit):
            real_index += np.product(num_cnt_in_each_digit[num_of_digit-i:]) * digit_list[num_of_digit-i-1]
        return real_index
    
    find_real_index_of_state("2345", [5,5,6,6])
    

    其结果与497相同。

    问题是,这个功能真的很慢。我需要更快的版本,但这个是我能想到的最好的。

    我真的需要你的建议来改善它的表现。 (例如矢量化等)

    由于

3 个答案:

答案 0 :(得分:1)

这是一种矢量化方法,利用np.cumprod执行迭代// output headers so that the file is downloaded rather than displayed header('Content-Type: text/csv; charset=utf-8'); header('Content-Disposition: attachment; filename=data.csv'); // create a file pointer connected to the output stream $output = fopen('php://output', 'w'); // output the column headings fputcsv($output, array('Column 1', 'Column 2', 'Column 3')); // fetch the data mysql_connect('localhost', 'username', 'password'); mysql_select_db('database'); $rows = mysql_query('SELECT field1,field2,field3 FROM table'); // loop over the rows, outputting them while ($row = mysql_fetch_assoc($rows)) fputcsv($output, $row); ,然后np.dot执行和减少 -

np.product

运行时测试 -

1)原始样本:

def real_index_vectorized(n, count):
    num = [int(d) for d in str(n)]
        # Or np.array([n]).view((str,1)).astype(int) #Thanks to @Eric
        # Or (int(n)//(10**np.arange(len(n)-1,-1,-1)))%10
    return np.dot( np.cumprod(count[:0:-1]), num[-2::-1]) + num[-1]

2)更大一点的样本:

In [66]: %timeit find_real_index_of_state("2345",[5,5,6,6])
100000 loops, best of 3: 14.1 µs per loop

In [67]: %timeit real_index_vectorized("2345",[5,5,6,6])
100000 loops, best of 3: 8.19 µs per loop

作为一个矢量化解决方案,当它与具有良好次循环迭代次数的循环版本竞争时,它可以很好地扩展。

答案 1 :(得分:1)

希望我能正确理解你。

我注意到的第一件事是你不需要重新计算每个循环的所有内容。即你单独计算(5 * 6 * 6 * 1),(6 * 6 * 1),(6 * 1),(1)而不需要计算一次。

def find_real_index_of_state(state,num_cnt_in_each_digit):

    factor = 1

    total = 0

    for digit, num_cnt in zip(reversed(state), reversed(num_cnt_in_each_digit)):

        digit = int(digit)

        total += digit*factor

        factor*= num_cnt

    return total

答案 2 :(得分:0)

为了表现,我建议你首先矢量化你的州:

base=np.array([5*6*6,6*6,6,1])
states=np.array(["2345","0010"])
numbers=np.frombuffer(states,np.uint32).reshape(-1,4)-48 # faster
ordinals=(base*numbers).sum(1)
#array([497,   6], dtype=int64)