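I created an ndarray subclass called "Parray" that takes two arguments: p and dimensionality. On its own it works fine. Now I want to create a class called SirPlotsAlot that inherits from Parray without the fancy __new__ and __array_finalize__ machinery.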
import numpy as np
class Parray(np.ndarray):
    def __new__(self, p = Parameters(), dimensionality = 2):
        print "Initializing Parray with initial dimensionality %s..." % dimensionality
        self.p = p # store the parameters
        if dimensionality == 2:
            shape = (p.nx, p.ny)
            self.pshape = shape
        elif dimensionality == 3:
            shape=(p.nx, p.ny, p.nx)
            self.pshape = shape
        else:
            raise NotImplementedError, "dimensionality must be 2 or 3"
        # ...Set other variables (elided)
        subarr = np.ndarray.__new__(self, shape, dtype, buffer, offset, strides, order)
        subarr[::] = np.zeros(self.pshape) # initialize to zero
        return subarr
...
class SirPlotsAlot(Parray):
    def __init__(self, p = Parameters(), dimensions = 3):
        super(SirPlotsAlot, self).__new__(p, dimensions) # (1)
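The objects in my program share one set of parameters by passing the object p = Parameters() back and forth.
Now, when I type the following (the file is auxiliary.py):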
import auxiliary
from parameters import Parameters
p = Parameters()
s = auxiliary.SirPlotsAlot(p, 3)
I expect a nice "Initializing Parray with initial dimensionality 3...", but I get "2" instead. And if I type:
import auxiliary
s = auxiliary.SirPlotsAlot()
I get
---> 67 shape = (p.nx, p.ny)
"AttributeError: 'int' object has no attribute 'nx'"
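It thinks "p" is an int, which it is not. If I fiddle with it, I get lots of strange, seemingly unrelated errors. It thinks p is "2". I am completely lost.
I have tried both with and without the line marked # (1) (the super call).
Other errors I have run into include "AttributeError: 'list' object has no attribute 'p'", "TypeError: __new__() takes exactly 2 arguments (1 given)", and "ValueError: need more than 0 values to unpack" (that last one appeared when I replaced the arguments of __new__ with *args, which I do not really understand).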
Answer 0 (score: 1)
I'll echo that and say "don't use __new__". Your Parray.__new__ method looks more like initialization, which belongs in __init__, just as in its subclass.
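One caveat when applying this advice to an ndarray subclass: the shape of an array is fixed when it is allocated, which is why NumPy's own subclassing guide does that step in __new__. What the advice does rule out is calling __new__ a second time from __init__. Below is a minimal sketch of that split. The Parameters class here is a hypothetical stand-in with nx and ny attributes, since the asker's real class is not shown, and the p=None default is a choice made for this sketch.

```python
import numpy as np


class Parameters(object):
    """Hypothetical stand-in for the asker's Parameters class."""
    def __init__(self, nx=4, ny=3):
        self.nx = nx
        self.ny = ny


class Parray(np.ndarray):
    def __new__(cls, p=None, dimensionality=2):
        # The first argument of __new__ is the class, not an instance.
        if p is None:
            p = Parameters()
        if dimensionality == 2:
            shape = (p.nx, p.ny)
        elif dimensionality == 3:
            shape = (p.nx, p.ny, p.nx)
        else:
            raise NotImplementedError("dimensionality must be 2 or 3")
        obj = np.zeros(shape).view(cls)  # zero-filled array of type cls
        obj.p = p                        # store the parameters on the instance
        return obj

    def __array_finalize__(self, obj):
        # Called for views and slices too; carry p over from the source array.
        self.p = getattr(obj, "p", None)


class SirPlotsAlot(Parray):
    def __new__(cls, p=None, dimensions=3):
        # Delegate to Parray.__new__, passing the class explicitly.
        return super(SirPlotsAlot, cls).__new__(cls, p, dimensions)


s = SirPlotsAlot(Parameters(nx=2, ny=5), 3)
print(s.shape)  # (2, 5, 2)
```

The subclass only changes the default dimensionality, so it needs no __init__ at all.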
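Answer 1 (score: 0)
It has been ten years and I have long since left this project, but I solved the problem by writing helper functions that create new instances of the class and set them up. In the full module listing below, see the definitions at the bottom of the file. Those are what I imported and used.
Credit to Matthew Schinckel for pointing out that __new__ should already have been called by the time __init__ runs, and for informing others' thinking.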
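The point about __new__ having already run can be illustrated without NumPy, separately from the author's module that follows. This is a minimal sketch with hypothetical Base and Child classes that mirror the call marked # (1) in the question. Because __new__ is a static method, super() does not supply the class, so every argument shifts one slot to the left: p lands where the class belongs, dimensions lands in p, and dimensionality falls back to its default of 2.

```python
calls = []  # one (first_arg, p, dimensionality) tuple per __new__ call


class Base(object):
    def __new__(cls, p="params", dimensionality=2):
        calls.append((cls, p, dimensionality))
        if not isinstance(cls, type):
            # Only reached through the mistaken call in Child.__init__.
            return None
        return object.__new__(cls)


class Child(Base):
    def __init__(self, p="params", dimensions=3):
        # Mistake mirrored from the question: by now __new__ has already
        # run once, and this second call passes p in place of the class.
        super(Child, self).__new__(p, dimensions)


c = Child("params", 3)
print(calls[1])  # ('params', 3, 2)
```

The first recorded call is the normal one made by Child(...) itself; the second shows p holding the int 3, which matches the "'int' object has no attribute 'nx'" error in the question.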
# -*- coding: utf-8 -*-
"""
Era's Plotting Functionality. This module exports SirPlotsAlot and company:
    class SirPlotsAlot: Array with 2D, 3D, animated plotting capability, and a pyrism Parameters object.
    def NewSirPlotsAlot(shape): returns a zero-filled SirPlotsAlot; the shape may be left unspecified.
    def NewPsirPlotsAlot(dimensionality, p): returns a zero-filled SirPlotsAlot sized from a Parameters object; doesn't need explicit parameters.
    def returns_SirPlotsAlot: decorator forcing an ndarray-returning function to return SirPlotsAlot instead.
Created on Thu Jul 12 18:46:15 2012
@author: Era
"""
# SirPlotsAlot
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
import matplotlib.pyplot as pyplot
import numpy as np
import scipy as scipy
import logging
lprint = logging.getLogger('pyrism')
class SirPlotsAlot(np.ndarray):
    """
    An array with 2D, 3D, animated plotting capability, and a pyrism Parameters object.
    Inherits: numpy's ndarray
    Input:
        shape: A tuple of ints. The shape of the underlying ndarray.
               (Use NewPsirPlotsAlot to build one from a dimensionality
               and a Parameters object.)
    """
    # class variables
    currentSlice = 0 # for _updateSlice and animated plots

    def __new__(cls, shape):
        """
        Creates a new SirPlotsAlot for us to use.
        SirPlotsAlot inherits ndarray. ndarray is written in C, and needs an extra
        method called __new__ to help it.
        Args:
            shape: A tuple of ints. The shape of the underlying ndarray.
        Returns:
            A SirPlotsAlot (an ndarray subclass instance)
        Author / Date:
            Erasmus Alcarin / January 23rd, 2012
            Erasmus Alcarin / July 13th, 2012
        """
        ### Specify the exact parameters of the array this class implements
        dtype = float   # dtype: data type. Optional.
                        # Any object that can be interpreted as a numpy datatype
        buffer = None   # buffer: object exposing buffer interface. Optional.
                        # Used to fill array with data
        offset = 0      # offset: int. Optional.
                        # Offset of array in data buffer
        strides = None  # strides: tuple of ints. Optional.
                        # Strides of data in memory
        order = None    # order: {'C', 'F'}. Optional.
                        # Row-major or column-major order.
        # Instantiate new ndarray (this class). Temporarily called subarr.
        # Passing cls is crucial: it creates an ndarray that is of type
        # THIS CLASS instead of type ndarray.
        subarr = np.ndarray.__new__(cls, shape, dtype, buffer, offset, strides, order)
        # Return the successfully created instance for this class to use!
        return subarr

    def __init__(self, shape):
        """Says hello!
        Args:
            shape: A tuple of ints. The shape of the underlying ndarray.
        Returns:
            None
        Author:
            Erasmus Alcarin / January 23rd, 2012
            Erasmus Alcarin / July 13th, 2012
        """
        lprint.debug("Ah, kind sir! Thy bidding be done!")

    def __array_finalize__(self, obj):
        """Allow inheritance of ndarray's view and slicing behaviour.
        Purpose: ndarray has a lot of functions which let you interact
                 with it (all its awesome features, specifically views
                 and so-called "new-from-template": that is, slices).
                 numpy calls this hook whenever such an array is created,
                 so our class also gets to use all of those nifty features!
        Args:
            obj: The array the new one was made from. For example, this
                 function is called if we type
                     myArr = myIntensityMap[1:]
                 (myArr is self, and myIntensityMap is obj). obj is None
                 when the array comes straight from ndarray.__new__.
        Returns:
            None
        Author / Date:
            Erasmus Alcarin / January 23rd, 2012
        """
        if obj is None:
            return

    def __array_wrap__(self, out_arr, context=None):
        """Allow inheritance of ndarray's ufunc results.
        Purpose: ndarray has a lot of functions which let you interact
                 with it (all its awesome features, specifically array
                 adding, multiplying, etc.). numpy calls this hook on the
                 result of such an operation, so the result keeps our type.
        Args:
            out_arr: What is returned in the operation which is being
                     performed.
            context: A parameter which __array_wrap__ is specified to take. (optional)
                     If you know, update me!
        Returns:
            See ndarray.__array_wrap__()
        Author / Date:
            Erasmus Alcarin / January 23rd, 2012
        """
        # Call ndarray's __array_wrap__ method.
        return np.ndarray.__array_wrap__(self, out_arr, context).view(type(self))

    def _enforceXD(self, X):
        """
        a helper function returning True if this
        SirPlotsAlot has dimensionality X, otherwise
        raising a ValueError.
        Args:
            X: An int. The underlying ndarray dimensionality being tested for.
        Returns:
            True if this array has dimensionality X.
            ValueError is raised otherwise.
        """
        if self.ndim == X:
            return True
        else:
            raise ValueError("A %sD array was required. A %sD array was supplied." % (X, self.ndim))

    def _checkXD(self, X):
        """
        a helper function returning True if this
        SirPlotsAlot has dimensionality X, and
        False otherwise.
        Args:
            X: An int. The underlying ndarray dimensionality being tested for.
        Returns:
            True if this array has dimensionality X.
            False otherwise.
        """
        return self.ndim == X

    def _checkLabelInfo(self, label = None):
        """
        a helper utility function to decide which
        of the accepted formats for plot label the user
        has specified.
        Args:
            label: The user's input, a tuple of up to four strings:
                   (title, xlabel, ylabel, zlabel)
        Returns:
            Nothing
        """
        if label is None:
            return
        if len(label) >= 1 and isinstance(label[0], str):
            pyplot.title(label[0])
        if len(label) >= 2 and isinstance(label[1], str):
            pyplot.xlabel(label[1])
        if len(label) >= 3 and isinstance(label[2], str):
            pyplot.ylabel(label[2])
        if len(label) >= 4 and isinstance(label[3], str):
            # pyplot has no zlabel(); only 3D axes provide set_zlabel()
            ax = pyplot.gca()
            if hasattr(ax, 'set_zlabel'):
                ax.set_zlabel(label[3])

    def _add_labels(self, label = None, caller_label = 'none'):
        """
        A utility function to quickly add labels to
        any of the graphing utilities embedded in
        SirPlotsAlot.
        Args:
            label: (Str, Tuple) The label being supplied by the user.
            caller_label: A string. Each plotting function has its own
                          axes to label. This identifies the plotting function
                          by the name of the attribute holding its default label.
        Returns:
            Nothing
        """
        if label is None:
            raise ValueError("_add_labels violated")
        else:
            lprint.debug("going on to labelling")
        if isinstance(label, str):
            pyplot.title(label)
        elif isinstance(label, tuple):
            self._checkLabelInfo(label)
        elif hasattr(self, caller_label):
            # fall back to the default label stored on this array
            label = getattr(self, caller_label)
            if isinstance(label, str):
                pyplot.title(label)
            elif isinstance(label, tuple):
                self._checkLabelInfo(label)
            else:
                print(getattr(self, caller_label))
                raise ValueError("_add_labels requires string or tuple of strings")

    def _updateSlice(self):
        """
        a helper function for animate2D(), this controls the
        progression (speed, sampling) of the animation by
        returning the next image to be presented in the animation.
        Args:
            None
        Returns:
            2D slice of this array.
        """
        if self._enforceXD(3):
            self.currentSlice += 1
            return self[self.currentSlice]

    def plot1D(self, label = None):
        """
        Plot 1-axis SirPlotsAlot in 2D, plotting array contents as y (up).
        Args:
            label: String or tuple labelling the plot.
        Returns:
            Nothing
        """
        if self._enforceXD(1):
            pyplot.figure()
            # self._add_labels(label, 'plot1D_label')
            pyplot.plot(self)
            pyplot.show()

    def plot2D(self, label = None):
        """
        Plot 2-axis SirPlotsAlot in 2D, plotting array contents as color.
        Args:
            label: String or tuple labelling the plot.
        Returns:
            Nothing
        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")
            fig = pyplot.figure()
            if type(label) == str: # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot2D_label'):
                pyplot.title(self.plot2D_label)
            plot = pyplot.imshow(self)
            fig.colorbar(plot)
            #colorbar.ax.set_yticklabels(["%.2f" % self.min(), '0', "%.2f" % self.max()])
            pyplot.gca().invert_yaxis()
            pyplot.xlabel('x')
            pyplot.ylabel('y')
            pyplot.show()

    def save_plot2D(self, file = None, label = None, cbar_ticks = None):
        """
        Saves plot of 2-axis SirPlotsAlot in 2D, plotting array contents as color,
        in .png format.
        Args:
            file: A string. The filename to save to. Default: ``output``
            label: String or tuple labelling the plot.
            cbar_ticks: Colorbar ticks for plot. Default: auto.
        Returns:
            Nothing
        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")
            fig = pyplot.figure()
            if label is not None:
                self._add_labels(label, 'plot2D_label')
            plot = pyplot.imshow(self)
            # "is None" also works when cbar_ticks is a numpy array
            if cbar_ticks is None:
                fig.colorbar(plot)
            else:
                cbar = fig.colorbar(plot, ticks=cbar_ticks) # Numbers
                cbar.ax.set_yticklabels([str(t) for t in cbar_ticks]) # Label
            pyplot.gca().invert_yaxis()
            if file is None:
                file = 'output'
            pyplot.savefig(file)
            # nice!

    def plot3D(self, label = None):
        """
        Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up).
        Args:
            label: String or tuple labelling the plot.
        Returns:
            Nothing
        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")
            # make grid with one sample per array column (x) and row (y)
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            x, y = np.meshgrid(x, y) # this is the same as make_2d
            fig = pyplot.figure()
            ax = fig.add_subplot(111, projection='3d') # make a 3D axis
            if type(label) == str: # if label is supplied, apply it.
                ax.set_title(label)
            elif hasattr(self, 'plot3D_label'):
                ax.set_title(self.plot3D_label)
            ax.plot_surface(x, y, self)
            ax.set_xlabel('x')
            ax.set_ylabel('y')
            pyplot.show()

    def plot3D_2(self, label = None):
        """
        Plots 2-axis SirPlotsAlot in 3D, plotting array contents as 3rd dimension (up),
        with contours projected onto each 2D cross-section of the 3D plot.
        Args:
            label: String or tuple labelling the plot.
        Returns:
            Nothing
        """
        if self._enforceXD(2):
            # Do not produce huge output
            #lprint.debug("We're plotting this up:\n%s" % self)
            lprint.debug("We're plotting you some goodies!")
            # make grid with one sample per array column (x) and row (y)
            x = np.linspace(0, self.shape[1], self.shape[1])
            y = np.linspace(0, self.shape[0], self.shape[0])
            x, y = np.meshgrid(x, y)
            fig = pyplot.figure()
            ax = fig.add_subplot(111, projection='3d')
            if type(label) == str:
                ax.set_title(label)
            elif hasattr(self, 'plot3D_2_label'):
                ax.set_title(self.plot3D_2_label)
            ax.plot_surface(x, y, self, rstride=8, cstride=8, alpha=0.3)
            ax.contour(x, y, self, zdir='z', offset=self.min())
            ax.contour(x, y, self, zdir='x', offset=0)
            ax.contour(x, y, self, zdir='y', offset=self.shape[0])
            ax.set_xlabel('x')
            ax.set_xlim(0, self.shape[1])
            ax.set_ylabel('y')
            ax.set_ylim(0, self.shape[0])
            ax.set_zlabel('z')
            ax.set_zlim(self.min(), self.max())
            pyplot.show()
        # the following is probably deprecated code for the above.
        '''if self._enforceXD(2):
            print "We're plotting this up:\n%s" % self
            # make grid from min to max with interval nx
            x = y = scipy.linspace(self.min(), self.max(), self.shape[0])
            [x, y] = scipy.meshgrid(x, y) # this is the same as make_2d
            fig = pyplot.figure()
            if type(label) == str: # if label is supplied, apply it.
                pyplot.title(label)
            elif hasattr(self, 'plot3D_label'):
                pyplot.title(self.plot3D_label)
            ax = Axes3D(fig) # make a 3D axis
            ax.plot_surface(x, y, self)
            pyplot.show()
        '''

    def aniPlot2D(self):
        """
        Generate successive 2D color plots using color for the data. Then play these
        plots in series, creating an animation. Requires 3D SirPlotsAlot.
        Args:
            None
        Returns:
            Nothing
        """
        if self._enforceXD(3):
            self.tplot = 0
            fig = pyplot.figure()
            #x = np.arange(0, self.shape[1])
            #y = np.arange(0, self.shape[0]).reshape(-1,1)
            ims = []
            imsappend = ims.append # optimization
            # self[t] indexes the first axis, so iterate over shape[0]
            for t in range(self.shape[0]):
                imsappend((pyplot.imshow(self[t]),))
            # keep a reference so the animation is not garbage-collected
            ani = animation.ArtistAnimation(fig, ims, interval=50, repeat_delay=3000, blit=True)
            pyplot.show()

def NewSirPlotsAlot(shape = (512, 512)):
    """
    Returns instance of SirPlotsAlot explicitly initialized to all zeros;
    arguments may be left unspecified.
    Args:
        shape: A tuple of ints. The shape of the underlying ndarray.
    Returns:
        SirPlotsAlot
    Author / Date:
        Erasmus Alcarin / July 13, 2012
    """
    s = SirPlotsAlot(shape)
    s[:] = np.zeros(s.shape)
    lprint.info("SirPlotsAlot has been populated with zeros.")
    return s
# Aliases for NewSirPlotsAlot
splot = NewSirPlotsAlot

def NewPsirPlotsAlot(dimensionality = 3, p = None):
    """
    Returns instance of SirPlotsAlot explicitly initialized to all zeros;
    arguments may be left unspecified.
    Args:
        dimensionality: An int. Number of dimensions for array. (optional)
        p: A Parameters object. Simulation parameters. (optional)
    Returns:
        SirPlotsAlot
    Author / Date:
        Erasmus Alcarin / July 13, 2012
    """
    lprint.debug("Initializing SirPlotsAlot with initial dimensionality %s..." % dimensionality)
    # NewSirPlotsAlot()
    try:
        import pyrism.parameters as par
    except ImportError:
        lprint.error("Use of pyrism as non-package detected. You must remain in the pyrism directory.")
        import parameters as par
    if p is None:
        p = par.Parameters.Instance()
    # extract size from parameters file, assuming size nx, ny
    if dimensionality == 2:
        shape = (p.ny, p.nx)
    elif dimensionality == 3:
        shape = (p.nx, p.ny, p.nx)
    else:
        raise NotImplementedError("dimensionality must be 2 or 3")
    # Make and Get object
    s = SirPlotsAlot(shape)
    s[:] = np.zeros(s.shape)
    lprint.info("SirPlotsAlot has been populated with zeros.")
    return s
# Aliases for NewPsirPlotsAlot
psplot = NewPsirPlotsAlot

def returns_SirPlotsAlot(fn):
    """
    A decorator that changes an ndarray to a SirPlotsAlot by
    means of the ndarray view function. (Returns SirPlotsAlot)
    """
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs).view(SirPlotsAlot)
    return wrapped
# Aliases for returns_SirPlotsAlot
returns_splot = returns_SirPlotsAlot
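
The returns_SirPlotsAlot decorator relies on ndarray.view, which reinterprets an existing array as another class without copying its data. Here is a minimal sketch of how such a decorator might be used. The Plottable class and the make_grid function are hypothetical stand-ins (Plottable plays the role of SirPlotsAlot), and functools.wraps is an addition that is not in the listing above.

```python
import functools

import numpy as np


class Plottable(np.ndarray):
    """Hypothetical minimal stand-in for SirPlotsAlot."""
    def describe(self):
        return "%dD array" % self.ndim


def returns_plottable(fn):
    """Make an ndarray-returning function return Plottable instead."""
    @functools.wraps(fn)  # keep fn's name and docstring on the wrapper
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs).view(Plottable)
    return wrapped


@returns_plottable
def make_grid(nx, ny):
    # Plain numpy code; the decorator converts the result.
    return np.zeros((ny, nx))


grid = make_grid(4, 3)
print(type(grid).__name__)  # Plottable
print(grid.describe())      # 2D array
```

Because view does not call the subclass's __new__ or __init__, any per-instance setup that has to survive this conversion would need to live in __array_finalize__.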