如何可靠地缩放大数据集的matplotlib pcolormesh图?

时间:2016-06-27 10:31:05

标签: python python-3.x matplotlib

我正在尝试使用matplotlib.pyplot中的pcolormesh绘制一些数据,但在保存输出时(特别是在适当缩放图像时)遇到了一些困难。

我使用的是Python v3.4和matplotlib v1.5.1(如果这有影响的话)。

这就是我的代码目前的样子:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

def GetData(data_entries, num_of_channels):
    """Build a synthetic DataFrame of test data.

    The frame holds a 'timestamp' column (evenly spaced from 1 to
    data_entries*21) followed by, for every channel, a 'randomNNN' column and
    a 'periodicNNN' column, where NNN is the zero-padded channel number.

    Args:
        data_entries: Number of rows (samples) to generate.
        num_of_channels: Number of 'random' and of 'periodic' columns.

    Returns:
        A pandas DataFrame with 1 + 2*num_of_channels columns.
    """
    highlight_width = 10  # samples forced to 1.5 at each end of a random channel
    columns = {'timestamp': np.linspace(1, data_entries*21, data_entries, endpoint=True)}

    # Uniform noise in [0, 1) with both ends pinned to 1.5 so they stand out.
    for channel in range(num_of_channels):
        noise = np.random.rand(data_entries)
        noise[:highlight_width] = 1.5
        noise[-highlight_width:] = 1.5
        columns['random%03d' % channel] = noise

    # All zeros apart from a spike of 5000 on every 65th sample.
    # Kept as a second loop so all 'random' keys are inserted before any 'periodic' key.
    for channel in range(num_of_channels):
        spikes = np.zeros(data_entries)
        spikes[::65] = 5000
        columns['periodic%03d' % channel] = spikes

    return pd.DataFrame(columns)

def GetSubPlotIndex(totalRows, totalCols, row):
    """Pack a sub-plot position into a single integer.

    The result has totalRows in the hundreds place, totalCols in the tens
    place and row in the units place (e.g. 2, 1, 1 -> 211).  The three parts
    only stay distinguishable while each of them is a single digit.
    """
    hundreds_and_tens = totalRows * 10 + totalCols
    return hundreds_and_tens * 10 + row

def PlotData(df, num_of_channels, field_names):
    """Draw one pcolormesh sub-plot per field type and save the figure as 'test.png'.

    The sub-plots are stacked in a single column on the current pyplot figure,
    with the timestamp along the X axis and the channel index along the Y axis.
    Samples equal to zero are masked so the black background shows through.

    Args:
        df: DataFrame with a 'timestamp' column and, for each field name, one
            column per channel named ``field_name + '%03d' % channel``.  The
            columns belonging to one field are selected as a positional range,
            so they are assumed to be contiguous and in channel order.
        num_of_channels: Number of channels (mesh rows) per field.
        field_names: Column-name prefixes to plot; one sub-plot is made for each.
    """
    import copy  # local import: only needed to copy the colour map below

    # Create the x/y mesh that the data will be plotted on
    x = df['timestamp']
    y = np.linspace(0, num_of_channels - 1, num_of_channels)
    X, Y = np.meshgrid(x, y)

    # Positional access, so the limits are right whatever the index labels are.
    x_min, x_max = x.iloc[0], x.iloc[-1]

    # Set up the colour palette used to render the data.
    # Make bad results (those that are masked) invisible so the background shows instead.
    # A copy is modified so the colour map registered with matplotlib stays untouched.
    palette = copy.copy(plt.get_cmap('autumn'))
    palette.set_bad(alpha=0.0)

    # Iterate through all of the field types and produce one plot for each
    for idx, field_name in enumerate(field_names):

        # Create this sub-plot
        subPlotIndex = GetSubPlotIndex(len(field_names), 1, idx + 1)
        ax = plt.subplot(subPlotIndex)
        if idx == 0:
            ax.set_title('Raw Data Time Series')

        # Set the axis scale to exactly meet the limits of the data set.
        ax.set_autoscale_on(False)
        plt.axis([x_min, x_max, 0, num_of_channels - 1])

        # Black background; set_axis_bgcolor only exists on older matplotlib releases.
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor('black')
        else:
            ax.set_axis_bgcolor('black')

        # Grab the columns of every channel of this field (end position inclusive)
        # and transpose so that time runs along the X axis and channel index along the Y.
        firstFftCol = df.columns.get_loc(field_name + "%03d" % 0)
        lastFftCol = df.columns.get_loc(field_name + "%03d" % (num_of_channels - 1))
        data = df.iloc[:, firstFftCol:lastFftCol + 1].values.T

        # Mask off data with zero's so that it doesn't obscure the data we're actually interested in.
        data = np.ma.masked_where(data == 0.0, data)

        # Actually create the data mesh so we can plot it
        z_min, z_max = data.min(), data.max()
        p = ax.pcolormesh(X, Y, data, cmap=palette, vmin=z_min, vmax=z_max)

        # NOTE(review): plot() without data adds no lines; kept from the original.
        plt.plot()

        # Label the plot and add a key
        plt.ylabel(field_name)
        plt.colorbar(p)

    # Label the plot
    plt.xlabel('Time (ms)')

    # Record the result
    plt.savefig('test.png', edgecolor='none', transparent=False)

if __name__ == '__main__':
    # Inputs for the demonstration run.
    fields_to_plot = ('random', 'periodic')
    num_of_channels = 255
    data_entries = 30000 # Large values here cause issues

    data = GetData(data_entries, num_of_channels)

    # Target image size in pixels: one column per sample plus 200 extra, and
    # for each plot one row per channel plus 50 extra.
    dpi = 80 # The default according to the documentation.
    num_of_plots = len(fields_to_plot)
    additional_vertical_space_per_plot = 50
    width_in_pixels = len(data.index) + 200
    height_in_pixels = num_of_plots * (num_of_channels + additional_vertical_space_per_plot)

    # The figure size is passed as pixels divided by dpi.
    figure_size = (width_in_pixels / dpi, height_in_pixels / dpi)
    fig = plt.figure(1, figsize=figure_size, dpi=dpi)

    PlotData(data, num_of_channels, fields_to_plot)

有1000个条目,结果看起来很好:

enter image description here

如果我将样本数增加到我想要绘制的大小(30000),图像的尺寸是正确的(宽度为30200像素),但我看到了很多空白的无效区域。下面是我所看到问题的缩小版示意图:

enter image description here

有没有办法用数据更准确地填充图像?

1 个答案:

答案 0 :(得分:0)

感谢@Dusch的提示,这似乎解决了一些问题:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

def GetData(data_entries, num_of_channels):
    """Build a synthetic DataFrame of test data.

    The frame holds a 'timestamp' column (evenly spaced from 1 to
    data_entries*21) followed by, for every channel, a 'randomNNN' column and
    a 'periodicNNN' column, where NNN is the zero-padded channel number.

    Args:
        data_entries: Number of rows (samples) to generate.
        num_of_channels: Number of 'random' and of 'periodic' columns.

    Returns:
        A pandas DataFrame with 1 + 2*num_of_channels columns.
    """
    highlight_width = 10  # samples forced to 1.5 at each end of a random channel
    columns = {'timestamp': np.linspace(1, data_entries*21, data_entries, endpoint=True)}

    # Uniform noise in [0, 1) with both ends pinned to 1.5 so they stand out.
    for channel in range(num_of_channels):
        noise = np.random.rand(data_entries)
        noise[:highlight_width] = 1.5
        noise[-highlight_width:] = 1.5
        columns['random%03d' % channel] = noise

    # All zeros apart from a spike of 5000 on every 65th sample.
    # Kept as a second loop so all 'random' keys are inserted before any 'periodic' key.
    for channel in range(num_of_channels):
        spikes = np.zeros(data_entries)
        spikes[::65] = 5000
        columns['periodic%03d' % channel] = spikes

    return pd.DataFrame(columns)

def GetSubPlotIndex(totalRows, totalCols, row):
    """Pack a sub-plot position into a single integer.

    The result has totalRows in the hundreds place, totalCols in the tens
    place and row in the units place (e.g. 2, 1, 1 -> 211).  The three parts
    only stay distinguishable while each of them is a single digit.
    """
    hundreds_and_tens = totalRows * 10 + totalCols
    return hundreds_and_tens * 10 + row

def PlotData(df, num_of_channels, field_names):
    """Draw one pcolormesh sub-plot per field type and save the figure as 'test.png'.

    The sub-plots are stacked in a single column on the current pyplot figure,
    with the timestamp along the X axis and the channel index along the Y axis.
    Samples equal to zero are masked so the black background shows through.
    The layout is tightened before saving so the plots fill the image.

    Args:
        df: DataFrame with a 'timestamp' column and, for each field name, one
            column per channel named ``field_name + '%03d' % channel``.  The
            columns belonging to one field are selected as a positional range,
            so they are assumed to be contiguous and in channel order.
        num_of_channels: Number of channels (mesh rows) per field.
        field_names: Column-name prefixes to plot; one sub-plot is made for each.
    """
    import copy  # local import: only needed to copy the colour map below

    # Create the x/y mesh that the data will be plotted on
    x = df['timestamp']
    y = np.linspace(0, num_of_channels - 1, num_of_channels)
    X, Y = np.meshgrid(x, y)

    # Positional access, so the limits are right whatever the index labels are.
    x_min, x_max = x.iloc[0], x.iloc[-1]

    # Set up the colour palette used to render the data.
    # Make bad results (those that are masked) invisible so the background shows instead.
    # A copy is modified so the colour map registered with matplotlib stays untouched.
    palette = copy.copy(plt.get_cmap('autumn'))
    palette.set_bad(alpha=0.0)

    # Express a fixed 20 pixel gap between plot and colour bar as a fraction of
    # the figure width in pixels; this value is handed to colorbar() as pad.
    fig = plt.gcf()
    image_width = fig.get_size_inches()[0] * fig.dpi  # size in pixels
    colorbar_padding_width_in_pixels = 20
    colorbar_padding = colorbar_padding_width_in_pixels / image_width

    # Iterate through all of the field types and produce one plot for each
    for idx, field_name in enumerate(field_names):

        # Create this sub-plot
        subPlotIndex = GetSubPlotIndex(len(field_names), 1, idx + 1)
        ax = plt.subplot(subPlotIndex)
        if idx == 0:
            ax.set_title('Raw Data Time Series')

        # Set the axis scale to exactly meet the limits of the data set.
        ax.set_autoscale_on(False)
        plt.axis([x_min, x_max, 0, num_of_channels - 1])

        # Black background; set_axis_bgcolor only exists on older matplotlib releases.
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor('black')
        else:
            ax.set_axis_bgcolor('black')

        # Grab the columns of every channel of this field (end position inclusive)
        # and transpose so that time runs along the X axis and channel index along the Y.
        firstFftCol = df.columns.get_loc(field_name + "%03d" % 0)
        lastFftCol = df.columns.get_loc(field_name + "%03d" % (num_of_channels - 1))
        data = df.iloc[:, firstFftCol:lastFftCol + 1].values.T

        # Mask off data with zero's so that it doesn't obscure the data we're actually interested in.
        data = np.ma.masked_where(data == 0.0, data)

        # Actually create the data mesh so we can plot it
        z_min, z_max = data.min(), data.max()
        p = ax.pcolormesh(X, Y, data, cmap=palette, vmin=z_min, vmax=z_max)

        # NOTE(review): plot() without data adds no lines; kept from the original.
        plt.plot()

        # Label this sub-plot
        plt.ylabel(field_name)

        # Add the colour bar with the fixed-width padding computed above
        plt.colorbar(p, pad=colorbar_padding)

    # Label the plot
    plt.xlabel('Time (ms)')

    # Tighten the layout BEFORE saving, otherwise it has no effect on the file.
    plt.tight_layout()

    # Record the result
    plt.savefig('test.png', edgecolor='none', transparent=False, bbox_inches='tight')

if __name__ == '__main__':
    # Inputs for the demonstration run.
    fields_to_plot = ('random', 'periodic')
    num_of_channels = 255
    data_entries = 30000 # Large values here cause issues

    data = GetData(data_entries, num_of_channels)

    # Target image size in pixels: one column per sample plus 200 extra, and
    # for each plot one row per channel plus 50 extra.
    dpi = 80 # The default according to the documentation.
    num_of_plots = len(fields_to_plot)
    additional_vertical_space_per_plot = 50
    width_in_pixels = len(data.index) + 200
    height_in_pixels = num_of_plots * (num_of_channels + additional_vertical_space_per_plot)

    # The figure size is passed as pixels divided by dpi.
    figure_size = (width_in_pixels / dpi, height_in_pixels / dpi)
    fig = plt.figure(1, figsize=figure_size, dpi=dpi)

    PlotData(data, num_of_channels, fields_to_plot)

最后的秘诀是:

  1. 在调用plt.savefig之前,立即添加对plt.tight_layout()的调用。
  2. 在plt.savefig调用中添加bbox_inches='tight'。
  3. 先计算20像素的填充相当于整体图像宽度的多大比例,得到colorbar_padding,然后在plt.colorbar调用中添加, pad=colorbar_padding。