在matplotlib中定义色彩映射的中点

时间:2011-09-13 15:04:37

标签: python matplotlib

我想设置色彩图的中间点,即我的数据从-5到10,我想零为中间。我认为这样做的方法是继承normalize和使用规范,但我没有找到任何例子,我不清楚,我究竟要实现什么。

9 个答案:

答案 0 :(得分:74)

我知道这是游戏的后期,但我刚刚完成了这个过程并提出了一个解决方案,它可能不如子类化规范化那么强大,但更简单。我认为在这里为后人分享它会很好。

功能

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to 
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax / (vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
    '''
    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False), 
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
    plt.register_cmap(cmap=newcmap)

    return newcmap

一个例子

biased_data = np.random.random_integers(low=-15, high=5, size=(37,37))

orig_cmap = matplotlib.cm.coolwarm
shifted_cmap = shiftedColorMap(orig_cmap, midpoint=0.75, name='shifted')
shrunk_cmap = shiftedColorMap(orig_cmap, start=0.15, midpoint=0.75, stop=0.85, name='shrunk')

fig = plt.figure(figsize=(6,6))
grid = AxesGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.5,
                label_mode="1", share_all=True,
                cbar_location="right", cbar_mode="each",
                cbar_size="7%", cbar_pad="2%")

# normal cmap
im0 = grid[0].imshow(biased_data, interpolation="none", cmap=orig_cmap)
grid.cbar_axes[0].colorbar(im0)
grid[0].set_title('Default behavior (hard to see bias)', fontsize=8)

im1 = grid[1].imshow(biased_data, interpolation="none", cmap=orig_cmap, vmax=15, vmin=-15)
grid.cbar_axes[1].colorbar(im1)
grid[1].set_title('Centered zero manually,\nbut lost upper end of dynamic range', fontsize=8)

im2 = grid[2].imshow(biased_data, interpolation="none", cmap=shifted_cmap)
grid.cbar_axes[2].colorbar(im2)
grid[2].set_title('Recentered cmap with function', fontsize=8)

im3 = grid[3].imshow(biased_data, interpolation="none", cmap=shrunk_cmap)
grid.cbar_axes[3].colorbar(im3)
grid[3].set_title('Recentered cmap with function\nand shrunk range', fontsize=8)

for ax in grid:
    ax.set_yticks([])
    ax.set_xticks([])

示例结果:

enter image description here

答案 1 :(得分:18)

这是一个子类化Normalize的解决方案。使用它

norm = MidPointNorm(midpoint=3)
imshow(X, norm=norm)

这是班级:

from numpy import ma
from matplotlib import cbook
from matplotlib.colors import Normalize

class MidPointNorm(Normalize):    
    def __init__(self, midpoint=0, vmin=None, vmax=None, clip=False):
        Normalize.__init__(self,vmin, vmax, clip)
        self.midpoint = midpoint

    def __call__(self, value, clip=None):
        if clip is None:
            clip = self.clip

        result, is_scalar = self.process_value(value)

        self.autoscale_None(result)
        vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint

        if not (vmin < midpoint < vmax):
            raise ValueError("midpoint must be between maxvalue and minvalue.")       
        elif vmin == vmax:
            result.fill(0) # Or should it be all masked? Or 0.5?
        elif vmin > vmax:
            raise ValueError("maxvalue must be bigger than minvalue")
        else:
            vmin = float(vmin)
            vmax = float(vmax)
            if clip:
                mask = ma.getmask(result)
                result = ma.array(np.clip(result.filled(vmax), vmin, vmax),
                                  mask=mask)

            # ma division is very slow; we can take a shortcut
            resdat = result.data

            #First scale to -1 to 1 range, than to from 0 to 1.
            resdat -= midpoint            
            resdat[resdat>0] /= abs(vmax - midpoint)            
            resdat[resdat<0] /= abs(vmin - midpoint)

            resdat /= 2.
            resdat += 0.5
            result = ma.array(resdat, mask=result.mask, copy=False)                

        if is_scalar:
            result = result[0]            
        return result

    def inverse(self, value):
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        vmin, vmax, midpoint = self.vmin, self.vmax, self.midpoint

        if cbook.iterable(value):
            val = ma.asarray(value)
            val = 2 * (val-0.5)  
            val[val>0]  *= abs(vmax - midpoint)
            val[val<0] *= abs(vmin - midpoint)
            val += midpoint
            return val
        else:
            val = 2 * (val - 0.5)
            if val < 0: 
                return  val*abs(vmin-midpoint) + midpoint
            else:
                return  val*abs(vmax-midpoint) + midpoint

答案 2 :(得分:16)

最简单的方法是使用vmin的{​​{1}}和vmax参数(假设您正在处理图像数据)而不是继承imshow

E.g。

matplotlib.colors.Normalize

enter image description here

答案 3 :(得分:5)

不确定您是否还在寻找答案。对我来说,尝试子类Normalize是不成功的。所以我专注于手动创建一个新的数据集,刻度和刻度标签,以获得我认为你的目标。

我发现matplotlib中的scale模块有一个用于通过'syslog'规则转换线图的类,所以我用它来转换数据。然后我缩放数据,使其从0变为1(Normalize通常做什么),但我将正数与负数不同。这是因为你的vmax和vmin可能不一样,所以.5 - &gt; 1可能覆盖比.5 - >更大的正范围。 0,负范围确实如此。我更容易创建一个例程来计算刻度和标签值。

下面是代码和示例图。

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mpl as mpl
import matplotlib.scale as scale

NDATA = 50
VMAX=10
VMIN=-5
LINTHRESH=1e-4

def makeTickLables(vmin,vmax,linthresh):
    """
    make two lists, one for the tick positions, and one for the labels
    at those positions. The number and placement of positive labels is 
    different from the negative labels.
    """
    nvpos = int(np.log10(vmax))-int(np.log10(linthresh))
    nvneg = int(np.log10(np.abs(vmin)))-int(np.log10(linthresh))+1
    ticks = []
    labels = []
    lavmin = (np.log10(np.abs(vmin)))
    lvmax = (np.log10(np.abs(vmax)))
    llinthres = int(np.log10(linthresh))
    # f(x) = mx+b
    # f(llinthres) = .5
    # f(lavmin) = 0
    m = .5/float(llinthres-lavmin)
    b = (.5-llinthres*m-lavmin*m)/2
    for itick in range(nvneg):
        labels.append(-1*float(pow(10,itick+llinthres)))
        ticks.append((b+(itick+llinthres)*m))
    # add vmin tick
    labels.append(vmin)
    ticks.append(b+(lavmin)*m)

    # f(x) = mx+b
    # f(llinthres) = .5
    # f(lvmax) = 1
    m = .5/float(lvmax-llinthres)
    b = m*(lvmax-2*llinthres) 
    for itick in range(1,nvpos):
        labels.append(float(pow(10,itick+llinthres)))
        ticks.append((b+(itick+llinthres)*m))
    # add vmax tick
    labels.append(vmax)
    ticks.append(b+(lvmax)*m)

    return ticks,labels


data = (VMAX-VMIN)*np.random.random((NDATA,NDATA))+VMIN

# define a scaler object that can transform to 'symlog'
scaler = scale.SymmetricalLogScale.SymmetricalLogTransform(10,LINTHRESH)
datas = scaler.transform(data)

# scale datas so that 0 is at .5
# so two seperate scales, one for positive and one for negative
data2 = np.where(np.greater(data,0),
                 .75+.25*datas/np.log10(VMAX),
                 .25+.25*(datas)/np.log10(np.abs(VMIN))
                 )

ticks,labels=makeTickLables(VMIN,VMAX,LINTHRESH)

cmap = mpl.cm.jet
fig = plt.figure()
ax = fig.add_subplot(111)
im = ax.imshow(data2,cmap=cmap,vmin=0,vmax=1)
cbar = plt.colorbar(im,ticks=ticks)
cbar.ax.set_yticklabels(labels)

fig.savefig('twoscales.png')

vmax=10,vmin=-5 and linthresh=1e-4

随意调整脚本顶部的“常量”(例如VMAX)以确认其表现良好。

答案 4 :(得分:5)

此解决方案的灵感来自this page

中具有相同名称的类

这里我创建了一个Normalize的子类,后跟一个最小的例子。

import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt


class MidpointNormalize(mpl.colors.Normalize):
    def __init__(self, vmin, vmax, midpoint=0, clip=False):
        self.midpoint = midpoint
        mpl.colors.Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        normalized_min = max(0, 1 / 2 * (1 - abs((self.midpoint - self.vmin) / (self.midpoint - self.vmax))))
        normalized_max = min(1, 1 / 2 * (1 + abs((self.vmax - self.midpoint) / (self.midpoint - self.vmin))))
        normalized_mid = 0.5
        x, y = [self.vmin, self.midpoint, self.vmax], [normalized_min, normalized_mid, normalized_max]
        return sp.ma.masked_array(sp.interp(value, x, y))


vals = sp.array([[-5., 0], [5, 10]]) 
vmin = vals.min()
vmax = vals.max()

norm = MidpointNormalize(vmin=vmin, vmax=vmax, midpoint=0)
cmap = 'RdBu_r' 

plt.imshow(vals, cmap=cmap, norm=norm)
plt.colorbar()
plt.show()

结果:pic-1

同样的例子只包含正面数据vals = sp.array([[1., 3], [6, 10]])

pic-2

总结一下 - 这个规范有以下属性:

  • 中点会变成中间色。
  • 上下范围将以相同的方式重新缩放,因此使用正确的色彩图时,色彩饱和度将对应于距中点的距离。
  • 颜色条仅显示图片上显示的颜色。
  • 即使vmin大于midpoint(似乎没有测试所有边缘情况),似乎也能正常工作。

答案 5 :(得分:3)

我使用了Paul H的优秀答案,但遇到了一个问题,因为我的一些数据范围从负到正,而其他数据范围从0到正数或从负数到0;在任何一种情况下,我希望0被着色为白色(我正在使用的色彩映射的中点)。使用现有实现,如果midpoint值等于1或0,则原始映射不会被覆盖。您可以在下图中看到: graphs before edit 第3列看起来正确,但第2列中的深蓝色区域和其余列中的深红色区域都应该是白色的(它们的数据值实际上是0)。使用我的修复程序给了我: graphs after edit 我的函数与Paul H的函数基本相同,我在for循环的开头编辑:

    def shiftedColorMap(cmap, min_val, max_val, name):
    '''Function to offset the "center" of a colormap. Useful for data with a negative min and positive max and you want the middle of the colormap's dynamic range to be at zero. Adapted from https://stackoverflow.com/questions/7404116/defining-the-midpoint-of-a-colormap-in-matplotlib

    Input
    -----
      cmap : The matplotlib colormap to be altered.
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower ofset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax/(vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highets point in the colormap's range.
          Defaults to 1.0 (no upper ofset). Should be between
          `midpoint` and 1.0.'''
    epsilon = 0.001
    start, stop = 0.0, 1.0
    min_val, max_val = min(0.0, min_val), max(0.0, max_val) # Edit #2
    midpoint = 1.0 - max_val/(max_val + abs(min_val))
    cdict = {'red': [], 'green': [], 'blue': [], 'alpha': []}
    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)
    # shifted index to match the data
    shift_index = np.hstack([np.linspace(0.0, midpoint, 128, endpoint=False), np.linspace(midpoint, 1.0, 129, endpoint=True)])
    for ri, si in zip(reg_index, shift_index):
        if abs(si - midpoint) < epsilon:
            r, g, b, a = cmap(0.5) # 0.5 = original midpoint.
        else:
            r, g, b, a = cmap(ri)
        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))
    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)
    plt.register_cmap(cmap=newcmap)
    return newcmap

编辑:当我的一些数据从一个小的正值到一个更大的正值时,我再次遇到了类似的问题,其中非常低的值是红色而不是白色。我通过在上面的代码中添加行Edit #2来修复它。

答案 6 :(得分:1)

如果你不介意计算vmin,vmax和0之间的比例,这是一个非常基本的线性地图,从蓝色到白色到红色,根据比例z设置白色:

def colormap(z):
    """custom colourmap for map plots"""

    cdict1 = {'red': ((0.0, 0.0, 0.0),
                      (z,   1.0, 1.0),
                      (1.0, 1.0, 1.0)),
              'green': ((0.0, 0.0, 0.0),
                        (z,   1.0, 1.0),
                        (1.0, 0.0, 0.0)),
              'blue': ((0.0, 1.0, 1.0),
                       (z,   1.0, 1.0),
                       (1.0, 0.0, 0.0))
              }

    return LinearSegmentedColormap('BlueRed1', cdict1)

cdict格式非常简单:行是渐变中创建的点:第一个条目是x值(沿着渐变的比率从0到1),第二个是前一个的结束值段,第三个是下一个段的起始值 - 如果你想要平滑的渐变,后两个总是相同的。 See the docs了解更多细节。

答案 7 :(得分:1)

请注意,在matplotlib版本3.1中,添加了DivergingNorm类。我认为它涵盖了您的用例。 可以这样使用:

from matplotlib import colors
colors.DivergingNorm(vmin=-4000., vcenter=0., vmax=10000)

答案 8 :(得分:0)

我遇到了类似的问题,但我希望最高值为全红色,并切断低蓝色值,使其看起来基本上像彩色底部被切掉了。这对我有用(包括可选的透明度):

def shift_zero_bwr_colormap(z: float, transparent: bool = True):
    """shifted bwr colormap"""
    if (z < 0) or (z > 1):
        raise ValueError('z must be between 0 and 1')

    cdict1 = {'red': ((0.0, max(-2*z+1, 0), max(-2*z+1, 0)),
                      (z,   1.0, 1.0),
                      (1.0, 1.0, 1.0)),

              'green': ((0.0, max(-2*z+1, 0), max(-2*z+1, 0)),
                        (z,   1.0, 1.0),
                        (1.0, max(2*z-1,0),  max(2*z-1,0))),

              'blue': ((0.0, 1.0, 1.0),
                       (z,   1.0, 1.0),
                       (1.0, max(2*z-1,0), max(2*z-1,0))),
              }
    if transparent:
        cdict1['alpha'] = ((0.0, 1-max(-2*z+1, 0), 1-max(-2*z+1, 0)),
                           (z,   0.0, 0.0),
                           (1.0, 1-max(2*z-1,0),  1-max(2*z-1,0)))

    return LinearSegmentedColormap('shifted_rwb', cdict1)

cmap =  shift_zero_bwr_colormap(.3)

x = np.arange(0, np.pi, 0.1)
y = np.arange(0, 2*np.pi, 0.1)
X, Y = np.meshgrid(x, y)
Z = np.cos(X) * np.sin(Y) * 5 + 5
plt.plot([0, 10*np.pi], [0, 20*np.pi], color='c', lw=20, zorder=-3)
plt.imshow(Z, interpolation='nearest', origin='lower', cmap=cmap)
plt.colorbar()