在Matplotlib中的单个窗口中显示多个3d条形图而不会变形

时间:2018-04-17 12:54:27

标签: python python-3.x matplotlib

我试图在一个窗口中显示多个3D条形图。我遇到的问题是这些图的变形。具体来说,它们似乎以不同的角度和不同的尺度绘制。

以下是我的代码片段:

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np

cols = ['Seller', 'Section', 'Store', 'Total Pieces:  Jan', 'Total Pieces:  Feb']

mainList = [['BonM', 'Butterfly', 'E-shop', '5250', '1400'],
            ['BonM', 'Butterfly', 'FL', '1085', '2450'],
            ['BonM', 'Butterfly', 'Holesovice', '1085', '2450'],
            ['Kla', 'Decorations', 'Holesovice', '105', '35'],
            ['LezDar', 'YRings', 'Holesovice', '0', '136'],
            ['LezDar', 'OtherB', 'E-shop', '0', '160'],
            ['LezDar', 'OtherB', 'Other', '0', '112'],
            ['PP', 'Skirt', 'FL', '3156', '1380'],
            ['PP', 'Skirt', 'Holesovice', '1320', '0'],
            ['PP', 'Skirt', 'SAS', '450', '0'],
            ['PP', 'Bags', 'E-shop', '250', '0'],
            ['PP', 'Skirt', 'E-shop', '5600', '0'],
            ['PP', 'Dress', 'Other', '6551', '3018'],
            ['Mar', 'Dress', 'Holesovice', '5000', '4000']]

loopVals = [x for x in cols if 'Total' in x]
leftOver = [x for x in cols if x not in loopVals]

fig = plt.figure()

for i, j in enumerate(loopVals):

    plotName = loopVals[i]
    refinedList = []
    for index, row in enumerate(mainList):
        if mainList[index][3 + i] != '0':
            refinedList.append(row)

    ax = fig.add_subplot(1 + i,1,1, projection = "3d")

    distinct_sellers = sorted(list(set([x[0] for x in refinedList])))
    distinct_stores = sorted(list(set([x[2] for x in refinedList])))
    all_items = [x[1] for x in refinedList]

    ax.set_xlim3d(0,2*len(distinct_sellers))
    ax.set_ylim3d(0,2*len(distinct_stores))

    xList = [i + (i + 1) for i, x in enumerate(distinct_sellers)]
    yList = [i + (i + 1) for i, x in enumerate(distinct_stores)]

    x_dict = dict(zip(distinct_sellers, xList))
    y_dict = dict(zip(distinct_stores, yList))

    xpos = [x_dict[x[0]] for x in refinedList]
    ypos = [y_dict[x[2]] for x in refinedList]
    zpos = [0 for i in range(len(xpos))]
    object_list = list(zip(xpos, ypos))

    z_list = []
    for index, tup in enumerate(object_list):
        emptyList = []
        emptyList.append(list(x_dict.keys())[list(x_dict.values()).index(tup[0])])
        emptyList.append(list(y_dict.keys())[list(y_dict.values()).index(tup[1])])
        z_list.append(emptyList)

    _rows = []
    object_names = []
    for ind, element in enumerate(all_items):
        _elements = []
        object_names.append(element)
        for index, item in enumerate(z_list):
            if item[0] in refinedList[index] and item[1] in refinedList[index] and element in refinedList[index]:
                _elements.append(int(refinedList[index][3 + i]))
            else:
                _elements.append(0)
        _rows.append(_elements)
    res = [list(tup) for tup in {tuple(item) for item in _rows}]
    pdz = [x for x in res if np.any(x)]

    tL = []
    hL = []
    for index, item in enumerate(object_names):
        if item not in hL:
            dic = dict(zip([item], [_rows[index]]))
            tL.append(dic)
        hL.append(item)

    dict_res = {k: v for d in tL for k, v in d.items() if np.any(v)}
    true_name = [list(dict_res.keys())[list(dict_res.values()).index(i)] for i in pdz]

    xm = [x + 0.5 for x in xpos]
    xmM = list(set(xm))
    ym = [y + 0.5 for y in ypos]
    ymM = list(set(ym))

    dx = [1 for i in range(len(xpos))]
    dy = [1 for i in range(len(ypos))]

    plt.xticks(xmM, distinct_sellers)
    plt.yticks(ymM, distinct_stores)

    objectList = []
    for i in range(len(pdz)):
        col = np.random.rand(3,)
        objectList.append(plt.Rectangle((0, 0), 1, 1, fc=col))
        ax.bar3d(xpos, ypos, zpos, dx, dy, np.asarray(pdz[i]), color=col)
        zpos += np.asarray(pdz[i])

    newList = []
    for index, element in enumerate(xpos):
        newList.append([xpos[index], ypos[index], zpos[index]])

    for index, element in enumerate(newList):
        ax.text(newList[index][0] + 0.5, newList[index][1] + 0.5, newList[index][2], newList[index][2], horizontalalignment='center', verticalalignment='bottom', color = 'grey')

    ax.legend(objectList, true_name)
    ax.text2D(0.05, 0.95, plotName, transform=ax.transAxes)

plt.gca().invert_xaxis()
plt.show()

这是显示上述变形的图片:

enter image description here

有没有办法防止以不同的角度/比例显示。我还想防止的另一个问题是"背后的重叠" bar(即前栏不可读,因为后栏显示在其上)。

我将不胜感激任何提示/建议。

1 个答案:

答案 0 :(得分:0)

我决定使用PYQT5库来显示图表,它工作得很好(不再有图形变形):

import sys
from PyQt5.QtWidgets import *
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

class MainW(QMainWindow):
    def __init__(self, parent=None):
        super(MainW, self).__init__(parent)

        cols = ['Seller', 'Section', 'Store', 'Total Pieces:  Jan', 'Total Pieces:  Feb']

        mainList = [['BonM', 'Butterfly', 'E-shop', '5250', '1400'],
                    ['BonM', 'Butterfly', 'FL', '1085', '2450'],
                    ['BonM', 'Butterfly', 'Holesovice', '1085', '2450'],
                    ['Kla', 'Decorations', 'Holesovice', '105', '35'],
                    ['LezDar', 'YRings', 'Holesovice', '0', '136'],
                    ['LezDar', 'OtherB', 'E-shop', '0', '160'],
                    ['LezDar', 'OtherB', 'Other', '0', '112'],
                    ['PP', 'Skirt', 'FL', '3156', '1380'],
                    ['PP', 'Skirt', 'Holesovice', '1320', '0'],
                    ['PP', 'Skirt', 'SAS', '450', '0'],
                    ['PP', 'Bags', 'E-shop', '250', '0'],
                    ['PP', 'Skirt', 'E-shop', '5600', '0'],
                    ['PP', 'Dress', 'Other', '6551', '3018'],
                    ['Mar', 'Dress', 'Holesovice', '5000', '4000']]

        loopVals = [x for x in cols if 'Total' in x]
        leftOver = [x for x in cols if x not in loopVals]

        widget = QWidget()
        layout = QVBoxLayout()
        widget.setLayout(layout)
        self.setCentralWidget(widget)

        for i, j in enumerate(loopVals):
            self.fig = plt.figure()
            self.canvas = FigureCanvas(self.fig)
            plotName = loopVals[i]
            layout.addWidget(self.canvas)

            refinedList = []
            for index, row in enumerate(mainList):
                if mainList[index][3 + i] != '0':
                    refinedList.append(row)

            ax = self.fig.add_subplot(1, 1, 1, projection = "3d")

            distinct_sellers = sorted(list(set([x[0] for x in refinedList])))
            distinct_stores = sorted(list(set([x[2] for x in refinedList])))
            all_items = [x[1] for x in refinedList]

            ax.set_xlim3d(0,2*len(distinct_sellers))
            ax.set_ylim3d(0,2*len(distinct_stores))

            xList = [i + (i + 1) for i, x in enumerate(distinct_sellers)]
            yList = [i + (i + 1) for i, x in enumerate(distinct_stores)]

            x_dict = dict(zip(distinct_sellers, xList))
            y_dict = dict(zip(distinct_stores, yList))

            xpos = [x_dict[x[0]] for x in refinedList]
            ypos = [y_dict[x[2]] for x in refinedList]
            zpos = [0 for i in range(len(xpos))]
            object_list = list(zip(xpos, ypos))

            z_list = []
            for index, tup in enumerate(object_list):
                emptyList = []
                emptyList.append(list(x_dict.keys())[list(x_dict.values()).index(tup[0])])
                emptyList.append(list(y_dict.keys())[list(y_dict.values()).index(tup[1])])
                z_list.append(emptyList)

            _rows = []
            object_names = []
            for ind, element in enumerate(all_items):
                _elements = []
                object_names.append(element)
                for index, item in enumerate(z_list):
                    if item[0] in refinedList[index] and item[1] in refinedList[index] and element in refinedList[index]:
                        _elements.append(int(refinedList[index][3 + i]))
                    else:
                        _elements.append(0)
                _rows.append(_elements)
            res = [list(tup) for tup in {tuple(item) for item in _rows}]
            pdz = [x for x in res if np.any(x)]

            tL = []
            hL = []
            for index, item in enumerate(object_names):
                if item not in hL:
                    dic = dict(zip([item], [_rows[index]]))
                    tL.append(dic)
                hL.append(item)

            dict_res = {k: v for d in tL for k, v in d.items() if np.any(v)}
            true_name = [list(dict_res.keys())[list(dict_res.values()).index(i)] for i in pdz]

            xm = [x + 0.5 for x in xpos]
            xmM = list(set(xm))
            ym = [y + 0.5 for y in ypos]
            ymM = list(set(ym))

            dx = [1 for i in range(len(xpos))]
            dy = [1 for i in range(len(ypos))]

            plt.xticks(xmM, distinct_sellers)
            plt.yticks(ymM, distinct_stores)

            objectList = []
            for i in range(len(pdz)):
                col = np.random.rand(3,)
                objectList.append(plt.Rectangle((0, 0), 1, 1, fc=col))
                ax.bar3d(xpos, ypos, zpos, dx, dy, np.asarray(pdz[i]), color=col)
                zpos += np.asarray(pdz[i])

            newList = []
            for index, element in enumerate(xpos):
                newList.append([xpos[index], ypos[index], zpos[index]])

            for index, element in enumerate(newList):
                ax.text(newList[index][0] + 0.5, newList[index][1] + 0.5, newList[index][2], newList[index][2], horizontalalignment='center', verticalalignment='bottom', color = 'grey')

            ax.legend(objectList, true_name)
            ax.text2D(0.05, 0.95, plotName, transform=ax.transAxes)

        plt.gca().invert_xaxis()

if __name__ == '__main__':

    app = QApplication(sys.argv)
    main = MainW()
    main.show()
    sys.exit(app.exec_())