如何将Z刻度标签的颜色与散点颜色相匹配?

时间:2018-01-03 11:20:03

标签: python matplotlib scatter-plot

编辑:这个问题已通过下面的循环解决:

# Give each Z tick label a colour from the same jet palette that is used for
# the scatter points.
# The tick-label list is fetched once instead of on every iteration, and zip()
# stops at the shorter sequence, so a palette with more entries than there are
# tick labels no longer raises IndexError.
palette = plt.cm.jet(np.linspace(0, 1, N))
z_labels = plt.gca().get_zticklabels()
for label, rgba in zip(z_labels, palette):
    label.set_color(rgba)

其中 plt.cm.jet(np.linspace(0,1,N)) 是所使用的调色板

如何将Z刻度标签的颜色与散点颜色相匹配?

我通过合并这两个答案来制作自定义的3D轴:

“How to color a specific gridline/tickline in 3D Matplotlib Scatter Plot figure?”——用于设置网格线颜色

https://stackoverflow.com/a/17927118/8881965 ——用于覆盖 _axinfo 属性

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.axis3d import Axis
import matplotlib.pyplot as plt
import matplotlib.projections as proj
from matplotlib.colors import colorConverter

# Per-axis settings, keyed by axis name. The custom axis classes in this file
# expose this mapping through their ``_AXINFO`` class attribute.
# NOTE(review): the inner keys ('i', 'tickdir', 'juggled', 'color') are
# presumably the ones mplot3d's Axis reads -- confirm against the installed
# matplotlib version.
custom_AXINFO = {
    # x axis: empty plane
    'x': {
        'i': 0,
        'tickdir': 1,
        'juggled': (1, 0, 2),
        'color': (0, 0, 0),
    },
    # y axis: date plane
    'y': {
        'i': 1,
        'tickdir': 0,
        'juggled': (0, 1, 2),
        'color': (0.756, 0.145, 0.184),
    },
    # z axis: bottom plane
    'z': {
        'i': 2,
        'tickdir': 0,
        'juggled': (0, 2, 1),
        'color': (0.835, 0.549, 0.164),
    },
}


class axis3d_custom(Axis):
    """3D axis whose individual gridlines can be given their own colour.

    Colour overrides are registered with ``set_gridline_color`` and applied
    every time the axis is drawn.
    """

    def __init__(self, adir, v_intervalx, d_intervalx, axes, *args, **kwargs):
        """Forward every argument to ``Axis`` and start with no overrides."""
        Axis.__init__(self, adir, v_intervalx, d_intervalx, axes, *args, **kwargs)
        # (value, color) pairs collected by set_gridline_color().
        self.gridline_colors = []

    def set_gridline_color(self, *gridline_info):
        """Register gridlines to recolour.

        Each positional argument is a ``(value, color)`` tuple: the location
        of the gridline to change and the colour to change it to. A list of
        such tuples may be passed by unpacking it with the ``*`` operator.
        """
        self.gridline_colors.extend(gridline_info)

    def draw(self, renderer):
        """Draw the axis, then redraw the gridlines with overrides applied."""
        Axis.draw(self, renderer)
        if not self.gridline_colors:
            # Nothing registered: the stock drawing above is all that is needed.
            return

        lower, upper = self.get_view_interval()
        if lower > upper:
            lower, upper = upper, lower

        # Rudimentary clipping: keep only the major tick locations that fall
        # inside the view interval, so that no extra grid lines are considered.
        visible_locs = [loc for loc in self.major.locator()
                        if lower <= loc <= upper]

        # Pair the index of each visible location with the colour registered
        # for exactly that value (plain ``==`` comparison).
        overrides = [(index, color)
                     for index, loc in enumerate(visible_locs)
                     for wanted, color in self.gridline_colors
                     if loc == wanted]

        line_colors = self.gridlines.get_colors()
        for index, color in overrides:
            line_colors[index] = colorConverter.to_rgba(color)
        self.gridlines.set_color(line_colors)
        self.gridlines.draw(renderer, project=True)

class XAxis(axis3d_custom):
    """Custom 3D x axis configured from ``custom_AXINFO``."""

    _AXINFO = custom_AXINFO

    def get_data_interval(self):
        """Return the ``intervalx`` of the axes' ``xy_dataLim`` (data limits)."""
        data_limits = self.axes.xy_dataLim
        return data_limits.intervalx

class YAxis(axis3d_custom):
    """Custom 3D y axis configured from ``custom_AXINFO``."""

    _AXINFO = custom_AXINFO

    def get_data_interval(self):
        """Return the ``intervaly`` of the axes' ``xy_dataLim`` (data limits)."""
        data_limits = self.axes.xy_dataLim
        return data_limits.intervaly

class ZAxis(axis3d_custom):
    """Custom 3D z axis configured from ``custom_AXINFO``."""

    _AXINFO = custom_AXINFO

    def get_data_interval(self):
        """Return the ``intervalx`` of the axes' ``zz_dataLim`` (data limits)."""
        data_limits = self.axes.zz_dataLim
        return data_limits.intervalx

class Axes3D_custom(Axes3D):
    """
    3D axes object that builds its x/y/z axes from the custom axis classes.
    """
    name = '3d_custom'

    def _init_axis(self):
        """Create the three custom 3D axes in place of the regular X/Y axes."""
        xy_view, xy_data = self.xy_viewLim, self.xy_dataLim
        zz_view, zz_data = self.zz_viewLim, self.zz_dataLim

        # Each axis is stored under both its ``w_*axis`` and ``*axis`` name;
        # the chained assignment binds ``w_*axis`` first, then ``*axis``.
        self.w_xaxis = self.xaxis = XAxis(
            'x', xy_view.intervalx, xy_data.intervalx, self)
        self.w_yaxis = self.yaxis = YAxis(
            'y', xy_view.intervaly, xy_data.intervaly, self)
        self.w_zaxis = self.zaxis = ZAxis(
            'z', zz_view.intervalx, zz_data.intervalx, self)

        for axis in (self.xaxis, self.yaxis, self.zaxis):
            axis.init3d()
# Register the custom axes class with matplotlib's projection registry so it
# can be looked up by its ``name`` ('3d_custom').
proj.register_projection(Axes3D_custom)

# Scatter 3D graph, one jet-palette colour per point.
# ``zdir`` selects which axis direction is treated as z and must be a
# direction string; the original passed the ``zs`` data array here. 'z' is the
# documented default direction.
# NOTE(review): ``bx``, ``xs``, ``ys``, ``zs`` and ``N`` are not defined in this
# snippet -- ``bx`` is presumably an axes created with projection='3d_custom'.
bx.scatter(xs, ys, zs, zdir='z', c=plt.cm.jet(np.linspace(0, 1, N)))

这就是图表的样子:(原帖此处为一张图片,“enter image description here”)

谢谢!

1 个答案:

答案 0 :(得分:0)

通过使用下面的循环解决了这个问题:
# Give each Z tick label a colour from the same jet palette that is used for
# the scatter points.
# The tick-label list is fetched once instead of on every iteration, and zip()
# stops at the shorter sequence, so a palette with more entries than there are
# tick labels no longer raises IndexError.
palette = plt.cm.jet(np.linspace(0, 1, N))
z_labels = plt.gca().get_zticklabels()
for label, rgba in zip(z_labels, palette):
    label.set_color(rgba)

其中plt.cm.jet(np.linspace(0,1,N))是使用的调色板