我正在尝试模拟抛硬币和获利并在matplotlib中绘制图形:
from random import choice
import matplotlib.pyplot as plt
import time
start_time = time.time()
num_of_graphs = 2000
tries = 2000
coins = [150, -100]
last_loss = 0
for a in range(num_of_graphs):
profit = 0
line = []
for i in range(tries):
profit = profit + choice(coins)
if (profit < 0 and last_loss < i):
last_loss = i
line.append(profit)
plt.plot(line)
plt.show()
print("--- %s seconds ---" % (time.time() - start_time))
print("No losses after " + str(last_loss) + " iterations")
最终结果是
--- 9.30498194695 seconds ---
No losses after 310 iterations
为什么运行该脚本需要这么长时间?如果将num_of_graphs
更改为10000,脚本将永远无法完成。
您将如何优化呢?
答案 0 :(得分:4)
您对执行时间的度量过于粗糙。以下内容使您可以测量模拟所需的时间,而无需进行绘图:
它正在使用numpy。
import matplotlib.pyplot as plt
import numpy as np
import time
def run_sims(num_sims, num_flips):
start = time.time()
sims = [np.random.choice(coins, num_flips).cumsum() for _ in range(num_sims)]
end = time.time()
print(f"sim time = {end-start}")
return sims
def plot_sims(sims):
start = time.time()
for line in sims:
plt.plot(line)
end = time.time()
print(f"plotting time = {end-start}")
plt.show()
if __name__ == '__main__':
start_time = time.time()
num_sims = 2000
num_flips = 2000
coins = np.array([150, -100])
plot_sims(run_sims(num_sims, num_flips))
sim time = 0.13962197303771973
plotting time = 6.621474981307983
您可以看到,sim时间大大减少了(在我的2011年笔记本电脑上,它的时间大约为7秒);绘制时间取决于matplotlib。
答案 1 :(得分:3)
matplotlib随着脚本的进行而变得越来越慢,因为它 重新绘制您先前绘制的所有线条-甚至 滚动到屏幕之外的那些。
这是西蒙·吉本斯(Simon Gibbons)先前回答的post的答案。
matplotlib并未针对速度进行优化,而是针对其图形进行了优化。以下是一些为提高速度而开发的链接:
您可以参考matplotlib cookbook了解有关性能的更多信息。
答案 2 :(得分:1)
为了更好地优化代码,我将始终尝试使用numpy进行矢量化来替换循环,或者根据我的具体需要,使用其他在后台使用numpy的库。
在这种情况下,您可以通过以下方式计算和绘制利润:
import matplotlib.pyplot as plt
import time
import numpy as np
start_time = time.time()
num_of_graphs = 2000
tries = 2000
coins = [150, -100]
# Create a 2-D array with random choices
# rows for tries, columns for individual runs (graphs).
coin_tosses = np.random.choice(coins, (tries, num_of_graphs))
# Caculate 2-D array of profits by summing
# cumulatively over rows (trials).
profits = coin_tosses.cumsum(axis=0)
# Plot everything in one shot.
plt.plot(profits)
plt.show()
print("--- %s seconds ---" % (time.time() - start_time))
在我的配置中,此代码带有aprox。运行6.3秒(6.2绘制),而代码则花费了将近15秒。