以下是一个例子:
所以,0000,0040,0111,4455还可以,但5555,4555,4466还不行。
我想要的是找到顺序中的2345是什么? (从零开始索引)`
例如,序数中0001为“1”。同样,0010是“5”。
可以通过计算,
(5 * 6 * 6 * 1)* 2 +(6 * 6 * 1)* 3 +(6 * 1)* 4 +(1)* 5 = 497
我在Python
import numpy as np
def find_real_index_of_state(state, num_cnt_in_each_digit):
"""
parameter
=========
state(str)
num_cnt_in_each_digit(list) : the number of number in each digit
"""
num_of_digit = len(state)
digit_list = [int(i) for i in state]
num_cnt_in_each_digit.append(1)
real_index = 0
for i in range(num_of_digit):
real_index += np.product(num_cnt_in_each_digit[num_of_digit-i:]) * digit_list[num_of_digit-i-1]
return real_index
find_real_index_of_state("2345", [5,5,6,6])
其结果与497相同。
问题是,这个功能真的很慢。我需要更快的版本,但这个是我能想到的最好的。
我真的需要你的建议来改善它的表现。 (例如矢量化等)
由于
答案 0 :(得分:1)
这是一种矢量化方法,利用np.cumprod
执行迭代// output headers so that the file is downloaded rather than displayed
header('Content-Type: text/csv; charset=utf-8');
header('Content-Disposition: attachment; filename=data.csv');
// create a file pointer connected to the output stream
$output = fopen('php://output', 'w');
// output the column headings
fputcsv($output, array('Column 1', 'Column 2', 'Column 3'));
// fetch the data
mysql_connect('localhost', 'username', 'password');
mysql_select_db('database');
$rows = mysql_query('SELECT field1,field2,field3 FROM table');
// loop over the rows, outputting them
while ($row = mysql_fetch_assoc($rows)) fputcsv($output, $row);
,然后np.dot
执行和减少 -
np.product
运行时测试 -
1)原始样本:
def real_index_vectorized(n, count):
num = [int(d) for d in str(n)]
# Or np.array([n]).view((str,1)).astype(int) #Thanks to @Eric
# Or (int(n)//(10**np.arange(len(n)-1,-1,-1)))%10
return np.dot( np.cumprod(count[:0:-1]), num[-2::-1]) + num[-1]
2)更大一点的样本:
In [66]: %timeit find_real_index_of_state("2345",[5,5,6,6])
100000 loops, best of 3: 14.1 µs per loop
In [67]: %timeit real_index_vectorized("2345",[5,5,6,6])
100000 loops, best of 3: 8.19 µs per loop
作为一个矢量化解决方案,当它与具有良好次循环迭代次数的循环版本竞争时,它可以很好地扩展。
答案 1 :(得分:1)
我注意到的第一件事是你不需要重新计算每个循环的所有内容。即你单独计算(5 * 6 * 6 * 1),(6 * 6 * 1),(6 * 1),(1)而不需要计算一次。
def find_real_index_of_state(state,num_cnt_in_each_digit):
factor = 1
total = 0
for digit, num_cnt in zip(reversed(state), reversed(num_cnt_in_each_digit)):
digit = int(digit)
total += digit*factor
factor*= num_cnt
return total
答案 2 :(得分:0)
为了表现,我建议你首先矢量化你的州:
base=np.array([5*6*6,6*6,6,1])
states=np.array(["2345","0010"])
numbers=np.frombuffer(states,np.uint32).reshape(-1,4)-48 # faster
ordinals=(base*numbers).sum(1)
#array([497, 6], dtype=int64)