我正在尝试使用 `matplotlib.pyplot` 中的 `pcolormesh` 绘制一些数据，但在保存输出时（特别是让图像按适当比例缩放时）遇到了一些困难。
如果版本有影响的话：我使用的是 Python 3.4 和 matplotlib 1.5.1。
我目前的代码如下：
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def GetData(data_entries, num_of_channels):
    """Build a synthetic DataFrame of test signals.

    Columns:
      * 'timestamp'   -- data_entries values evenly spaced from 1 to
                         data_entries * 21 (both ends included).
      * 'randomNNN'   -- one column per channel of uniform [0, 1) noise whose
                         first and last 10 samples are overwritten with 1.5 so
                         the ends of the series stand out.
      * 'periodicNNN' -- one column per channel that is zero except for the
                         value 5000 at every 65th sample, starting at sample 0.
    NNN is the zero-padded channel number (000, 001, ...).
    """
    highlight = 10
    columns = {
        'timestamp': np.linspace(1, data_entries * 21, data_entries, endpoint=True),
    }

    for channel in range(num_of_channels):
        noise = np.random.rand(data_entries)
        noise[:highlight] = 1.5
        noise[-highlight:] = 1.5
        columns['random%03d' % channel] = noise

    for channel in range(num_of_channels):
        spikes = np.zeros(data_entries)
        spikes[::65] = 5000
        columns['periodic%03d' % channel] = spikes

    return pd.DataFrame(columns)
def GetSubPlotIndex(totalRows, totalCols, row):
    """Encode a sub-plot position as matplotlib's three-digit shorthand.

    ``plt.subplot(212)`` means ``plt.subplot(2, 1, 2)``: the hundreds digit is
    the number of rows, the tens digit the number of columns and the units
    digit the 1-based position within the grid.

    Args:
        totalRows: number of rows in the sub-plot grid (1-9).
        totalCols: number of columns in the sub-plot grid (1-9).
        row: 1-based position of the sub-plot in the grid (1-9, and no larger
            than totalRows * totalCols).

    Returns:
        The integer code, e.g. ``GetSubPlotIndex(2, 1, 2) == 212``.

    Raises:
        ValueError: if a value does not fit in one digit or the position lies
            outside the grid. The shorthand cannot express such layouts, and
            the arithmetic would silently produce a code for a *different*
            layout (e.g. 1 row, 1 column, position 12 -> 122, the second cell
            of a 1x2 grid). Use ``plt.subplot(rows, cols, index)`` instead.
    """
    for label, value in (('totalRows', totalRows), ('totalCols', totalCols), ('row', row)):
        if not 1 <= value <= 9:
            raise ValueError(
                '%s=%r does not fit in a three-digit subplot code; '
                'use plt.subplot(rows, cols, index) instead' % (label, value))
    if row > totalRows * totalCols:
        raise ValueError('row=%r is outside a %dx%d grid' % (row, totalRows, totalCols))
    return totalRows * 100 + totalCols * 10 + row
def _cell_edges(samples):
    """Return len(samples) + 1 cell boundaries for a flat-shaded pcolormesh.

    Sample i occupies [samples[i], samples[i + 1]); the last sample reuses the
    preceding spacing (or a width of 1.0 when there is only one sample), so
    every sample, including the last, gets its own cell.
    """
    samples = np.asarray(samples, dtype=float)
    last_step = samples[-1] - samples[-2] if samples.size > 1 else 1.0
    return np.append(samples, samples[-1] + last_step)


def PlotData(df, num_of_channels, field_names, output_path='test.png'):
    """Draw one heat-map sub-plot per field, stacked vertically, and save it.

    For each prefix in *field_names*, the columns ``<prefix>000`` ...
    ``<prefix>NNN`` (NNN = num_of_channels - 1) of *df* are drawn with
    ``pcolormesh``: time (``df['timestamp']``) runs along X and channel k fills
    the band Y in [k, k + 1). Exact zeros are masked so the black axes
    background shows through, and every sub-plot gets its own colour bar.
    The figure is drawn into the current pyplot figure.

    Args:
        df: DataFrame with a 'timestamp' column plus the per-channel columns
            described above (the layout GetData produces).
        num_of_channels: number of channel columns to draw for each field.
        field_names: sequence of column-name prefixes, e.g. ('random', 'periodic').
        output_path: file the figure is saved to (default 'test.png').

    Raises:
        ValueError: if *df* has no rows.
        KeyError: if an expected channel column is missing from *df*.
    """
    import copy  # function-scope import keeps this block self-contained

    field_names = list(field_names)
    data_entries = len(df.index)
    if data_entries == 0:
        raise ValueError('df has no rows to plot')

    # Flat-shaded pcolormesh needs one more edge than cells in each direction;
    # give every sample and every channel a cell of its own.
    x_edges = _cell_edges(df['timestamp'])
    y_edges = np.arange(num_of_channels + 1)
    X, Y = np.meshgrid(x_edges, y_edges)

    # Use a private copy of the colour map: older matplotlib returns the shared
    # registered instance from get_cmap(), so set_bad() would leak into every
    # other plot. Masked cells become fully transparent so the background shows.
    palette = copy.deepcopy(plt.get_cmap('autumn'))
    palette.set_bad(alpha=0.0)

    ax = None
    for idx, field_name in enumerate(field_names):
        # The three-argument form of subplot() also works beyond nine fields.
        ax = plt.subplot(len(field_names), 1, idx + 1)
        if idx == 0:
            ax.set_title('Raw Data Time Series')

        # Set the axis limits to exactly cover the data set.
        ax.set_autoscale_on(False)
        ax.set_xlim(x_edges[0], x_edges[-1])
        ax.set_ylim(y_edges[0], y_edges[-1])
        # Works on both old (set_axis_bgcolor) and new (set_facecolor) APIs.
        ax.patch.set_facecolor('black')

        # Select the channel columns by name, so all channels (including the
        # last) are drawn regardless of column order in df, then transpose so
        # time runs along X and channel index along Y.
        columns = ['%s%03d' % (field_name, chan) for chan in range(num_of_channels)]
        data = np.asarray(df[columns], dtype=float).T

        # Mask exact zeros so they don't obscure the data of interest.
        data = np.ma.masked_where(data == 0.0, data)

        # Scale colours to the unmasked values; if everything is masked,
        # min()/max() would return `masked`, so let matplotlib decide.
        if data.count():
            z_min, z_max = data.min(), data.max()
        else:
            z_min = z_max = None
        p = ax.pcolormesh(X, Y, data, cmap=palette, vmin=z_min, vmax=z_max)

        # Label the plot and add a key.
        ax.set_ylabel(field_name)
        plt.colorbar(p, ax=ax)

    # The last sub-plot is the bottom one; it carries the shared time label.
    if ax is not None:
        ax.set_xlabel('Time (ms)')

    plt.savefig(output_path, edgecolor='none', transparent=False)
if __name__ == '__main__':
    fields_to_plot = ('random', 'periodic')
    num_of_channels = 255
    data_entries = 30000  # Large values here cause issues

    frame = GetData(data_entries, num_of_channels)

    # Target image size in pixels: one pixel per sample across (plus 200 px
    # of slack) and, for each sub-plot, one pixel per channel plus 50 px.
    pixels_wide = len(frame.index) + 200
    pixels_high = len(fields_to_plot) * (num_of_channels + 50)

    # 80 dpi was matplotlib 1.x's default (2.0 changed it to 100). It is
    # passed explicitly, so figsize * dpi reproduces the pixel sizes above.
    dpi = 80
    fig = plt.figure(1, figsize=(pixels_wide / dpi, pixels_high / dpi), dpi=dpi)

    PlotData(frame, num_of_channels, fields_to_plot)
当数据只有 1000 个条目时，结果看起来很好：
但如果把样本数增加到我实际需要绘制的数量（30000），虽然图像尺寸是正确的（宽度为 30200 像素），图中却出现了大量空白区域。下面是该问题的缩小示意图：
有没有办法让数据更充分地填满整幅图像？
答案 0（得分：0）
感谢 @Dusch 的提示，以下代码似乎解决了其中一些问题：
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
def GetData(data_entries, num_of_channels):
    """Build a synthetic DataFrame of test signals.

    Columns:
      * 'timestamp'   -- data_entries values evenly spaced from 1 to
                         data_entries * 21 (both ends included).
      * 'randomNNN'   -- one column per channel of uniform [0, 1) noise whose
                         first and last 10 samples are overwritten with 1.5 so
                         the ends of the series stand out.
      * 'periodicNNN' -- one column per channel that is zero except for the
                         value 5000 at every 65th sample, starting at sample 0.
    NNN is the zero-padded channel number (000, 001, ...).
    """
    highlight = 10
    columns = {
        'timestamp': np.linspace(1, data_entries * 21, data_entries, endpoint=True),
    }

    for channel in range(num_of_channels):
        noise = np.random.rand(data_entries)
        noise[:highlight] = 1.5
        noise[-highlight:] = 1.5
        columns['random%03d' % channel] = noise

    for channel in range(num_of_channels):
        spikes = np.zeros(data_entries)
        spikes[::65] = 5000
        columns['periodic%03d' % channel] = spikes

    return pd.DataFrame(columns)
def GetSubPlotIndex(totalRows, totalCols, row):
    """Encode a sub-plot position as matplotlib's three-digit shorthand.

    ``plt.subplot(212)`` means ``plt.subplot(2, 1, 2)``: the hundreds digit is
    the number of rows, the tens digit the number of columns and the units
    digit the 1-based position within the grid.

    Args:
        totalRows: number of rows in the sub-plot grid (1-9).
        totalCols: number of columns in the sub-plot grid (1-9).
        row: 1-based position of the sub-plot in the grid (1-9, and no larger
            than totalRows * totalCols).

    Returns:
        The integer code, e.g. ``GetSubPlotIndex(2, 1, 2) == 212``.

    Raises:
        ValueError: if a value does not fit in one digit or the position lies
            outside the grid. The shorthand cannot express such layouts, and
            the arithmetic would silently produce a code for a *different*
            layout (e.g. 1 row, 1 column, position 12 -> 122, the second cell
            of a 1x2 grid). Use ``plt.subplot(rows, cols, index)`` instead.
    """
    for label, value in (('totalRows', totalRows), ('totalCols', totalCols), ('row', row)):
        if not 1 <= value <= 9:
            raise ValueError(
                '%s=%r does not fit in a three-digit subplot code; '
                'use plt.subplot(rows, cols, index) instead' % (label, value))
    if row > totalRows * totalCols:
        raise ValueError('row=%r is outside a %dx%d grid' % (row, totalRows, totalCols))
    return totalRows * 100 + totalCols * 10 + row
def _cell_edges(samples):
    """Return len(samples) + 1 cell boundaries for a flat-shaded pcolormesh.

    Sample i occupies [samples[i], samples[i + 1]); the last sample reuses the
    preceding spacing (or a width of 1.0 when there is only one sample), so
    every sample, including the last, gets its own cell.
    """
    samples = np.asarray(samples, dtype=float)
    last_step = samples[-1] - samples[-2] if samples.size > 1 else 1.0
    return np.append(samples, samples[-1] + last_step)


def PlotData(df, num_of_channels, field_names, output_path='test.png'):
    """Draw one heat-map sub-plot per field, stacked vertically, and save it.

    For each prefix in *field_names*, the columns ``<prefix>000`` ...
    ``<prefix>NNN`` (NNN = num_of_channels - 1) of *df* are drawn with
    ``pcolormesh``: time (``df['timestamp']``) runs along X and channel k fills
    the band Y in [k, k + 1). Exact zeros are masked so the black axes
    background shows through, and every sub-plot gets its own colour bar
    placed about 20 px from the plot. Margins are tightened before saving so
    the data fills as much of the image as possible.

    Args:
        df: DataFrame with a 'timestamp' column plus the per-channel columns
            described above (the layout GetData produces).
        num_of_channels: number of channel columns to draw for each field.
        field_names: sequence of column-name prefixes, e.g. ('random', 'periodic').
        output_path: file the figure is saved to (default 'test.png').

    Raises:
        ValueError: if *df* has no rows.
        KeyError: if an expected channel column is missing from *df*.
    """
    import copy  # function-scope import keeps this block self-contained

    field_names = list(field_names)
    data_entries = len(df.index)
    if data_entries == 0:
        raise ValueError('df has no rows to plot')

    fig = plt.gcf()
    colorbar_padding_width_in_pixels = 20

    # Flat-shaded pcolormesh needs one more edge than cells in each direction;
    # give every sample and every channel a cell of its own.
    x_edges = _cell_edges(df['timestamp'])
    y_edges = np.arange(num_of_channels + 1)
    X, Y = np.meshgrid(x_edges, y_edges)

    # Use a private copy of the colour map: older matplotlib returns the shared
    # registered instance from get_cmap(), so set_bad() would leak into every
    # other plot. Masked cells become fully transparent so the background shows.
    palette = copy.deepcopy(plt.get_cmap('autumn'))
    palette.set_bad(alpha=0.0)

    ax = None
    for idx, field_name in enumerate(field_names):
        # The three-argument form of subplot() also works beyond nine fields.
        ax = plt.subplot(len(field_names), 1, idx + 1)
        if idx == 0:
            ax.set_title('Raw Data Time Series')

        # Set the axis limits to exactly cover the data set.
        ax.set_autoscale_on(False)
        ax.set_xlim(x_edges[0], x_edges[-1])
        ax.set_ylim(y_edges[0], y_edges[-1])
        # Works on both old (set_axis_bgcolor) and new (set_facecolor) APIs.
        ax.patch.set_facecolor('black')

        # Select the channel columns by name, so all channels (including the
        # last) are drawn regardless of column order in df, then transpose so
        # time runs along X and channel index along Y.
        columns = ['%s%03d' % (field_name, chan) for chan in range(num_of_channels)]
        data = np.asarray(df[columns], dtype=float).T

        # Mask exact zeros so they don't obscure the data of interest.
        data = np.ma.masked_where(data == 0.0, data)

        # Scale colours to the unmasked values; if everything is masked,
        # min()/max() would return `masked`, so let matplotlib decide.
        if data.count():
            z_min, z_max = data.min(), data.max()
        else:
            z_min = z_max = None
        p = ax.pcolormesh(X, Y, data, cmap=palette, vmin=z_min, vmax=z_max)

        ax.set_ylabel(field_name)

        # colorbar's `pad` is a fraction of the parent Axes (not the figure),
        # so convert the pixel gap using this axes' width in pixels. On a very
        # wide figure the default pad (0.05) is a large gap.
        axes_width_in_pixels = ax.get_position().width * fig.get_figwidth() * fig.dpi
        colorbar_padding = colorbar_padding_width_in_pixels / axes_width_in_pixels
        plt.colorbar(p, ax=ax, pad=colorbar_padding)

    # The last sub-plot is the bottom one; it carries the shared time label.
    if ax is not None:
        ax.set_xlabel('Time (ms)')

    # tight_layout must run *before* savefig; calling it afterwards has no
    # effect on the file that was already written.
    plt.tight_layout()
    plt.savefig(output_path, edgecolor='none', transparent=False, bbox_inches='tight')
if __name__ == '__main__':
    fields_to_plot = ('random', 'periodic')
    num_of_channels = 255
    data_entries = 30000  # Large values here cause issues

    frame = GetData(data_entries, num_of_channels)

    # Target image size in pixels: one pixel per sample across (plus 200 px
    # of slack) and, for each sub-plot, one pixel per channel plus 50 px.
    pixels_wide = len(frame.index) + 200
    pixels_high = len(fields_to_plot) * (num_of_channels + 50)

    # 80 dpi was matplotlib 1.x's default (2.0 changed it to 100). It is
    # passed explicitly, so figsize * dpi reproduces the pixel sizes above.
    dpi = 80
    fig = plt.figure(1, figsize=(pixels_wide / dpi, pixels_high / dpi), dpi=dpi)

    PlotData(frame, num_of_channels, fields_to_plot)
最终的关键改动是：
- 在调用 `plt.savefig` 之前立即添加 `plt.tight_layout()`；
- 在 `plt.savefig` 调用中添加 `bbox_inches='tight'`；
- 计算出 `colorbar_padding` 之后，在 `plt.colorbar` 调用中添加 `, pad=colorbar_padding`。