我正在尝试解决动态编程问题,并且我想出了一个简单的基于循环的算法,该算法基于一系列if
语句填充二维数组,如下所示:
s = # some string of size n
opt = numpy.zeros(shape=(n, n))
for j in range(0, n):
for i in range(j, -1, -1):
if j - i == 0:
opt[i, j] = 1
elif j - i == 1:
opt[i, j] = 2 if s[i] == s[j] else 1
elif s[i] == s[j] and opt[i + 1, j - 1] == (j - 1) - (i + 1) + 1:
opt[i, j] = 2 + opt[i + 1, j - 1]
else:
opt[i, j] = max(opt[i + 1, j], opt[i, j - 1], opt[i + 1, j - 1])
不幸的是,对于较大的N值,此代码非常慢。我发现最好使用内置函数(如numpy.where
和numpy.fill
来填充数组的值)相对于for循环,但是我很难找到任何示例来说明如何使这些函数(或其他优化的numpy
方法)与一系列if
语句一起使用,就像我的算法一样。用内置的numpy
库重写上述代码以使其针对Python进行更好的优化的合适方法是什么?
答案 0 :(得分:1)
我认为np.where和np.fill不能解决您的问题。 np.where用于返回满足特定条件的numpy数组的元素,但是在您的情况下,条件是numpy数组的 NOT VALUES 变量i和j的值。
对于您的特定问题,我建议您使用Cython专门针对较大的N值优化代码。Cython基本上是Python和C之间的接口。Cython的优点在于它可以保留python语法,但是使用C结构对其进行优化。它允许您以类似于C的方式定义变量类型,以加快计算速度。例如,使用Cython将i和j定义为整数将大大加快速度,因为在每次循环迭代时都会检查i和j的类型。
此外,Cython允许您使用C定义经典,快速的2D数组。然后,您可以使用指针快速访问此2D数组,而不必使用numpy数组。对于您而言,opt将是该2D阵列。
答案 1 :(得分:0)
您的if语句和赋值语句的左侧包含对要在循环中修改的数组的引用。这意味着将没有通用的方法将循环转换为数组操作。因此,您陷入了某种for循环。
如果您使用的是更简单的循环:
for j in range(0, n):
for i in range(j, -1, -1):
if j - i == 0:
opt[i, j] = 1
elif j - i == 1:
opt[i, j] = 2
elif s[i] == s[j]:
opt[i, j] = 3
else:
opt[i, j] = 4
您可以构造布尔数组(使用一些broadcasting)来表示您的三个条件:
import numpy as np
# get arrays i and j that represent the row and column indices
i,j = np.ogrid[:n, :n]
# construct an array with the characters from s
sarr = np.fromiter(s, dtype='U1').reshape(1, -1)
cond1 = i==j # result will be a bool arr with True wherever row index equals column index
cond2 = j==i+1 # result will be a bool arr with True wherever col index equals (row index + 1)
cond3 = sarr==sarr.T # result will be a bool arr with True wherever s[i]==s[j]
然后您可以使用numpy.select
来构建所需的opt
:
opt = np.select([cond1, cond2, cond3], [1, 2, 3], default=4)
对于n=5
和s='abbca'
,这将产生:
array([[1, 2, 4, 4, 3],
[4, 1, 2, 4, 4],
[4, 3, 1, 2, 4],
[4, 4, 4, 1, 2],
[3, 4, 4, 4, 1]])
答案 2 :(得分:0)
这是向量化的解决方案。
它将对角线视图创建到输出数组中,这使我们能够在对角线方向上进行累加。
分步说明:
在对角线视图中评估s [i] == s [j]。
仅保留通过右上至左下方向的一系列True连接到主要或第一个子对角线的那些
将所有True替换为2s,但主对角线改为1s。从左下到右上的方向求和
最后,取自下而上和左右方向的累计最大值
这并不是很明显,这与我在许多示例(使用下面的函数stresstest
)上测试的循环代码相同,并且看起来是正确的。对于中等大小的字符串(1-100个字符),速度大约要快7倍。
import numpy as np
def loopy(s):
n = len(s)
opt = np.zeros(shape=(n, n), dtype=int)
for j in range(0, n):
for i in range(j, -1, -1):
if j - i == 0:
opt[i, j] = 1
elif j - i == 1:
opt[i, j] = 2 if s[i] == s[j] else 1
elif s[i] == s[j] and opt[i + 1, j - 1] == (j - 1) - (i + 1) + 1:
opt[i, j] = 2 + opt[i + 1, j - 1]
else:
opt[i, j] = max(opt[i + 1, j], opt[i, j - 1], opt[i + 1, j - 1])
return opt
def vect(s):
n = len(s)
h = (n+1) // 2
s = np.array([s, s]).view('U1').ravel()
opt = np.zeros((n+2*h-1, n+2*h-1), int)
y, x = opt.strides
hh = np.lib.stride_tricks.as_strided(opt[h-1:, h-1:], (2, h, n), (x, x-y, x+y))
p, o, c = np.ogrid[:2, :h, :n]
hh[...] = 2 * np.logical_and.accumulate(s[c+o+p] == s[c-o], axis=1)
np.einsum('ii->i', opt)[...] = 1
hh[...] = hh.cumsum(axis=1)
opt = np.maximum.accumulate(opt[-h-1:None if h == 1 else h-2:-1, h-1:-h], axis=0)[::-1]
return np.maximum.accumulate(opt, axis=1)
def stresstest(n=100):
from string import ascii_lowercase
import random
from timeit import timeit
Tv, Tl = 0, 0
for i in range(n):
s = ''.join(random.choices(ascii_lowercase[:random.randint(2, 26)], k=random.randint(1, 100)))
print(s, end=' ')
assert np.all(vect(s) == loopy(s))
Tv += timeit(lambda: vect(s), number=10)
Tl += timeit(lambda: loopy(s), number=10)
print()
print(f"total time loopy {Tl}, vect {Tv}")
演示:
>>> stresstest(20)
caccbbdbcfbfdcacebbecffacabeddcfdededeeafaebeaeedaaedaabebfacbdd fckjhrmupcqmihlohjog dffffgalbdbhkjigladhgdjaaagelddehahbbhejkibdgjhlkbcihiejdgidljfalfhlaglcgcih eacdebdcfcdcccaacfccefbccbced agglljlhfj mvwlkedblhvwbsmvtbjpqhgbaolnceqpgkhfivtbkwgbvujskkoklgforocj jljiqlidcdolcpmbfdqbdpjjjhbklcqmnmkfckkch ohsxiviwanuafkjocpexjmdiwlcmtcbagksodasdriieikvxphksedajwrbpee mcwdxsoghnuvxglhxcxxrezcdkahpijgujqqrqaideyhepfmrgxndhyifg omhppjaenjprnd roubpjfjbiafulerejpdniniuljqpouimsfukudndgtjggtbcjbchhfcdhrgf krutrwnttvqdemuwqwidvntpvptjqmekjctvbbetrvehsgxqfsjhoivdvwonvjd adiccabdbifigeigdfaieecceciaghadiaigibehdaichfibeaggcgdciahfegefigghgebhddciaei llobdegpmebejvotsr rtnsevatjvuowmquaulfmgiwsophuvlablslbwrpnhtekmpphsenarhrptgbjvlseeqstewjgfhopqwgmcbcihljeguv gcjlfihmfjbkdmimjknamfbahiccbhnceiahbnhghnlleimmieglgbfjbnmemdgddndhinncegnmgmfmgahhhjkg nhbnfhp cyjcygpaaeotcpwfhnumcfveq snyefmeuyjhcglyluezrx hcjhejhdaejchedbce
total time loopy 0.2523909523151815, vect 0.03500175685621798