使用数组/其他方法进行优化

时间:2018-08-21 17:08:28

标签: python pandas csv

我有这段代码,用于本质上预测预测值/实际值(部门生产的单位),并且从我有100个标记为1-100的csv文件的文件夹中提取预测值。我告诉程序要先返回文件并将其与产生的实际值(在basefile / baseheader中给出)进行比较。基本上,每个文件都具有当月产生的实际单位,然后具有12个月的预测。我已经使用该程序进行了计算,但是运行速度很慢。当我指定的范围较大时,显示图形需要30秒钟以上,但是,如果我只是看到几个月的范围,它的运行速度很快。有人告诉我,我需要更改代码的一部分,即“合计”,“合计”,“三合计”等,并将它们更改为数组。但是,我不确定该怎么做,以后我也不知道如何绘制精确值(因为我需要先绘制一个合计,然后绘制两个合计,等等(然后显示一个合计,两个合计等的值分布) on)关于如何解决此问题以加快程序速度的任何想法?

import csv
import pandas as pd
import matplotlib.pyplot as plt

# Beginning part of code is user input and calcultions to determine the Department,Range, and there is a function called getfileheader


def nmonthaccuracy(basefilenumber, n):
    basefileread = pd.read_csv(str(basefilenumber) + ".csv", encoding="Latin-1")
    baseheader = getfileheader(basefilenumber)
    basefilevalue = basefileread.loc[
        basefileread["Customer"].str.contains(Department, na=False), baseheader
    ]

    nmonthread = pd.read_csv(str(basefilenumber - n) + ".csv", encoding="Latin-1")
    nmonthvalue = nmonthread.loc[
        nmonthread["Customer"].str.contains(Department, na=False), baseheader
    ]

    return (
        (1 - (int(basefilevalue) / int(nmonthvalue)) + 1)
        if int(nmonthvalue) > int(basefilevalue)
        else int(nmonthvalue) / int(basefilevalue)
    )


N = 13
total = [0] * N
total_by_month_list = [[] for _ in range(N)]
for basefilenumber in range(int(StartRange), int(EndRange)):
    for n in range(N):
        total[n] += nmonthaccuracy(basefilenumber, n)
        total_by_month_list[n].append(nmonthaccuracy(basefilenumber, n))

onetotal=total[1]/ Range
twototal=total[2]/ Range
threetotal=total[3]/ Range
fourtotal=total[4]/ Range
fivetotal=total[5]/ Range
sixtotal=total[6]/ Range
seventotal=total[7]/ Range
eighttotal=total[8]/ Range
ninetotal=total[9]/ Range
tentotal=total[10]/ Range
eleventotal=total[11]/ Range
twelvetotal=total[12]/ Range
onetotallist=total_by_month_list[1]
twototallist=total_by_month_list[2]
threetotallist=total_by_month_list[3]
fourtotallist=total_by_month_list[4]
fivetotallist=total_by_month_list[5]
sixtotallist=total_by_month_list[6]
seventotallist=total_by_month_list[7]
eighttotallist=total_by_month_list[8]
ninetotallist=total_by_month_list[9]
tentotallist=total_by_month_list[10]
eleventotallist=total_by_month_list[11]
twelvetotallist=total_by_month_list[12]






x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
y = [
    onetotal,
    twototal,
    threetotal,
    fourtotal,
    fivetotal,
    sixtotal,
    seventotal,
    eighttotal,
    ninetotal,
    tentotal,
    eleventotal,
    twelvetotal,
]
z = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
w = [
    (onetotallist),
    (twototallist),
    (threetotallist),
    (fourtotallist),
    (fivetotallist),
    (sixtotallist),
    (seventotallist),
    (eighttotallist),
    (ninetotallist),
    (tentotallist),
    (eleventotallist),
    (twelvetotallist),
]


fig, ax = plt.subplots()
for ze, we in zip(z, w):
    plt.scatter([ze] * len(we), we, marker="D", s=5)


plt.xlabel("Number of months forecast")
plt.ylabel("Predicted/Actual ratio")
plt.title("Predicted to actual ratio for n month forecast")
ax.plot(x, y, label="Predicted/Actual")
for a, b in zip(x, y):
    plt.text(a, b, str(round(b, 2)))
plt.scatter(x, y)
plt.show()

这是一个虚拟的csv文件: enter image description here

还有我的图表(如果好奇的话) enter image description here

1 个答案:

答案 0 :(得分:1)

好的,有几件事:

  • 您正在读取每个文件24次(每N次两次)。
  • 生成x / y / z / w时不需要重复

我所做的优化/更改大致是:

  • 使用lru_cache装饰器确保每个文件仅读取一次(并保存在内存中;如果出现问题,则可以限制lru缓存的大小-请参阅文档)
  • 为列表使用从零开始的索引(0..11);这对于Python来说是惯用的。

请注意,这是干编码,因此可能存在一些错误或遗漏:)


import functools

import pandas as pd
import matplotlib.pyplot as plt

Department = ...
Range = ...
StartRange = ...
EndRange = ...


# The lru_cache decorator will ensure each file is read into memory only once (and kept there)
@functools.lru_cache()
def read_file(n):
    return pd.read_csv(str(n) + ".csv", encoding="Latin-1")


def nmonthaccuracy(basefilenumber, n):
    basefileread = read_file(basefilenumber)
    baseheader = getfileheader(basefilenumber)
    basefilevalue = basefileread.loc[
        basefileread["Customer"].str.contains(Department, na=False), baseheader
    ]

    nmonthread = read_file(basefilenumber - n)
    nmonthvalue = nmonthread.loc[
        nmonthread["Customer"].str.contains(Department, na=False), baseheader
    ]

    return (
        (1 - (int(basefilevalue) / int(nmonthvalue)) + 1)
        if int(nmonthvalue) > int(basefilevalue)
        else int(nmonthvalue) / int(basefilevalue)
    )


N = 12
total_by_month_list = [[] for _ in range(N)]
for basefilenumber in range(int(StartRange), int(EndRange)):
    for n in range(N):
        # note "n+1" below since n is now zero-indexed
        total_by_month_list[n].append(nmonthaccuracy(basefilenumber, n + 1))

total = [sum(by_month) for by_month in total_by_month_list]

x = list(range(1, N + 1))  # [1..12]
y = [t / Range for t in total]
z = x  # same as x
w = total_by_month_list

# ... matplotlib code ...