我正在尝试用Haskell解决Project Euler problem #92。我最近开始学习Haskell。这是我试图用Haskell解决的第一个Project Euler问题,但是我的代码片段在10分钟内也没有终止。我知道你不直接给我答案,但我应该再次警告我用c ++找回答并不能给出欧拉的答案或解决欧拉的新逻辑。我只是好奇为什么那个人不能快速工作,我该怎么办才能让它更快?
{--EULER 92--}
import Data.List
myFirstFunction 1 = 0
myFirstFunction 89 = 1
myFirstFunction x= myFirstFunction (giveResult x)
giveResult 0 = 0
giveResult x = (square (mod x 10)) + (giveResult (div x 10))
square x = x*x
a=[1..10000000]
main = putStrLn(show (sum (map myFirstFunction a)))
答案 0 :(得分:22)
当然,使用更好的算法可以获得最大的加速。不过,我并没有深入探讨这一点。
因此,让我们专注于改进使用过的算法,而不是真正改变它。
您永远不会给出任何类型签名,因此类型默认为任意精度Integer
。这里的所有内容都很容易适应Int
,没有溢出的危险,所以让我们使用它。添加类型签名myFirstFunction :: Int -> Int
有助于:时间从Total time 13.77s ( 13.79s elapsed)
降至Total time 6.24s ( 6.24s elapsed)
,总分配下降约15倍。对于这种简单的更改,这不是坏事。
您使用div
和mod
。这些总是计算非负余数和相应的商,因此如果涉及一些负数,它们需要一些额外的检查。函数quot
和rem
映射到机器分区指令,它们不涉及此类检查,因此更快一些。如果通过LLVM后端(-fllvm
)进行编译,那么这也会利用您总是除以一个已知数字(10)的事实,并将除法转换为乘法和位移。现在时间:Total time 1.56s ( 1.56s elapsed)
。
我们不是单独使用quot
和rem
,而是使用同时计算两者的quotRem
函数,这样我们就不会重复除法(即使乘法+移位需要一点时间):
giveResult x = case x `quotRem` 10 of
(q,r) -> r*r + giveResult q
这并没有多大收获,只有一点点:Total time 1.49s ( 1.49s elapsed)
。
您正在使用列表a = [1 .. 10000000]
和map
该列表中的函数,然后sum
生成的列表。这是惯用的,整洁的,但不是超快的,因为分配所有这些列表单元和垃圾收集它们也需要时间 - 不是很多,因为GHC 非常擅长这一点,但是将其转换为循环
main = print $ go 0 1
where
go acc n
| n > 10000000 = acc
| otherwise = go (acc + myFirstFunction n) (n+1)
让我们稍微停顿一下:Total time 1.34s ( 1.34s elapsed)
,分配从最后一个列表版本的880,051,856 bytes allocated in the heap
下降到51,840 bytes allocated in the heap
。
giveResult
是递归的,因此无法内联。同样适用于myFirstFunction
,因此每次计算都需要两次函数调用(至少)。我们可以通过将giveResult
重写为非递归包装器和递归本地循环来避免这种情况,
giveResult x = go 0 x
where
go acc 0 = acc
go acc n = case n `quotRem` 10 of
(q,r) -> go (acc + r*r) q
这样可以内联:Total time 1.04s ( 1.04s elapsed)
。
这些是最明显的观点,进一步的改进 - 除了哈马尔在评论中提到的备忘录 - 需要一些思考。
我们现在在
module Main (main) where
myFirstFunction :: Int -> Int
myFirstFunction 1 = 0
myFirstFunction 89 = 1
myFirstFunction x= myFirstFunction (giveResult x)
giveResult :: Int -> Int
giveResult x = go 0 x
where
go acc 0 = acc
go acc n = case n `quotRem` 10 of
(q,r) -> go (acc + r*r) q
main :: IO ()
main = print $ go 0 1
where
go acc n
| n > 10000000 = acc
| otherwise = go (acc + myFirstFunction n) (n+1)
使用-O2 -fllvm
,在此处运行1.04秒,但使用本机代码生成器(仅-O2
),需要3.5秒。这种差异是由于GHC本身并没有将除法转换为乘法和位移的事实。如果我们手动完成,我们可以从本机代码生成器获得相同的性能。
因为我们知道编译器没有的东西,即我们从不在这里处理负数,并且数字不会变大,我们甚至可以产生更好的乘法和移位(会比编译器产生错误的负数或大股息结果,本机代码生成器的时间缩短为0.9秒,LLVM后端的时间缩短为0.73秒:
import Data.Bits
qr10 :: Int -> (Int, Int)
qr10 n = (q, r)
where
q = (n * 0x66666667) `unsafeShiftR` 34
r = n - 10 * q
注意:这要求Int
是64位类型,它不能使用32位Int
,它会产生错误的结果对于否定n
,对于大n
,乘法将溢出。我们正在进入肮脏的黑客领域。我们可以使用Word
代替Int
来减轻肮脏,只留下溢出(n <= 10737418236
Word
与n <= 5368709118
Int
不会发生溢出} #include <stdio.h>
unsigned int myFirstFunction(unsigned int i);
unsigned int giveResult(unsigned int i);
int main(void) {
unsigned int sum = 0;
for(unsigned int i = 1; i <= 10000000; ++i) {
sum += myFirstFunction(i);
}
printf("%u\n",sum);
return 0;
}
unsigned int myFirstFunction(unsigned int i) {
if (i == 1) return 0;
if (i == 89) return 1;
return myFirstFunction(giveResult(i));
}
unsigned int giveResult(unsigned int i) {
unsigned int acc = 0, r, q;
while(i) {
q = (i*0x66666667UL) >> 34;
r = i - q*10;
i = q;
acc += r*r;
}
return acc;
}
,所以在这里我们舒适地处于安全区。时间没有受到影响。
相应的C程序
gcc -O3
执行类似,使用clang -O3
编译,运行0.78秒,<= 7*9²
运行0.71。
在不改变算法的情况下,这几乎就是结束。
现在,算法的一个微小变化就是记忆。如果我们为数字module Main (main) where
import Data.Array.Unboxed
import Data.Array.IArray
import Data.Array.Base (unsafeAt)
import Data.Bits
qr10 :: Int -> (Int, Int)
qr10 n = (q, r)
where
q = (n * 0x66666667) `unsafeShiftR` 34
r = n - 10 * q
digitSquareSum :: Int -> Int
digitSquareSum = go 0
where
go acc 0 = acc
go acc n = case qr10 n of
(q,r) -> go (acc + r*r) q
table :: UArray Int Int
table = array (0,567) $ assocs helper
where
helper :: Array Int Int
helper = array (0,567) [(i, f i) | i <- [0 .. 567]]
f 0 = 0
f 1 = 0
f 89 = 1
f n = helper ! digitSquareSum n
endPoint :: Int -> Int
endPoint n = table `unsafeAt` digitSquareSum n
main :: IO ()
main = print $ go 0 1
where
go acc n
| n > 10000000 = acc
| otherwise = go (acc + endPoint n) (n+1)
建立一个查找表,我们只需要计算每个数字的数字平方和,而不是迭代它直到我们达到1或89,所以让我们回忆一下,
gcc -O3
手动进行记忆而不是使用库会使代码更长,但我们可以根据需要定制代码。我们可以使用一个未装箱的数组,我们可以省略数组访问的边界检查。两者都显着加快了计算速度。本机代码生成器的时间现在为0.18秒,LLVM后端的时间为0.13秒。相应的C程序用clang -O3
编译运行0.16秒,用ghc -O2 -fllvm
编译0.145秒(Haskell击败C,w00t!)。
然而,所使用的算法不能很好地扩展,比线性更差,并且对于10 8 的上限(具有适当调整的记忆限制),它在1.5秒内运行(clang -O3
),resp。 1.64秒(gcc -O3
)和1.87秒(10 = 1×3² + 1×1²
10 = 2×2² + 2×1²
10 = 1×2² + 6×1²
10 = 10×1²
)[本机代码生成器的2.02秒]。
使用不同的算法,通过将这些数字划分为数字的平方和来计算序列以1结尾的数字(直接产生1的唯一数字是10的幂。我们可以写
<= 10^10
从第一次开始,我们获得13,31,103,130,301,310,1003,1030,1300,3001,3010,3100,...... 从第二个,我们获得1122,1212,1221,2112,2121,2211,11022,11202 ...... 从第三个1111112,1111121,......
只有13,31,103,130,301,310是数字100 = 1×9² + 1×4² + 3×1²
...
100 = 1×8² + 1×6²
...
数字的可能平方和,因此只需要进一步调查。我们可以写
$ time ./problem92 7
8581146
real 0m0.010s
user 0m0.008s
sys 0m0.002s
$ time ./problem92 8
85744333
real 0m0.022s
user 0m0.018s
sys 0m0.003s
$ time ./problem92 9
854325192
real 0m0.040s
user 0m0.033s
sys 0m0.006s
$ time ./problem92 10
8507390852
real 0m0.074s
user 0m0.069s
sys 0m0.004s
这些分区中的第一个不生成子项,因为它需要五个非零数字,另一个明确给出生成两个子项68和86(如果限制为10 8 ,则为608,对于更大的限制,更多)),我们可以获得更好的缩放和更快的算法。
当我解决这个问题时,我写回来的相当未被优化的程序运行(输入是限制的10的指数)
{{1}}
在另一个联赛中。
答案 1 :(得分:8)
首先,我冒昧地清理你的代码:
endsAt89 1 = 0
endsAt89 89 = 1
endsAt89 n = endsAt89 (sumOfSquareDigits n)
sumOfSquareDigits 0 = 0
sumOfSquareDigits n = (n `mod` 10)^2 + sumOfSquareDigits (n `div` 10)
main = print . sum $ map endsAt89 [1..10^7]
在我糟糕的上网本上是1分13秒。让我们看看我们是否可以改善这一点。
由于数字很小,我们可以先使用机器大小的Int
而不是任意大小的Integer
。这只是添加类型签名的问题,例如
sumOfSquareDigits :: Int -> Int
这大大缩短了20秒的运行时间。
由于这些数字都是正数,我们可以将div
和mod
替换为稍快一点quot
和rem
,或者甚至同时使用{{1} }}:
quotRem
现在运行时间为17秒。让它的尾部递归剃掉另一秒:
sumOfSquareDigits :: Int -> Int
sumOfSquareDigits 0 = 0
sumOfSquareDigits n = r^2 + sumOfSquareDigits q
where (q, r) = quotRem x 10
为了进一步改进,我们可以注意到sumOfSquareDigits :: Int -> Int
sumOfSquareDigits n = loop n 0
where
loop 0 !s = s
loop n !s = loop q (s + r^2)
where (q, r) = quotRem n 10
对于给定的输入数字最多返回sumOfSquareDigits
,因此我们可以记住小数字以减少所需的迭代次数。这是我的最终版本(使用data-memocombinators包进行记忆):
567 = 7 * 9^2
在我的机器上运行时间不到9秒。