我有一个函数带有参数y_0
,该参数可能是列表或数组,也可能只是单个浮点数。我希望它能在y_0
是单个浮点的情况下工作,因此我尝试在下面的代码中使用np.asarray(y_0)
,这样,如果只有一个项目,则循环仍然可以工作。但是,我得到了错误TypeError: iteration over a 0-d array
。我可以使用if语句检查它是否为单例,并采取适当的措施。但是,我很好奇是否有一种方法可以对单个对象进行迭代?
def vf_grapher(fn, t_0, t_n, dt, y_0, lintype='-r', sup_title=None,
title=None, xlab=None, ylab=None):
t = np.arange(t_0, t_n, dt)
y_min = .0
y_max = .0
fig, axs = plt.subplots()
fig.suptitle(sup_title)
axs.set_title(title)
axs.set_ylabel(ylab)
axs.set_xlabel(xlab)
for iv in np.asarray(y_0):
soln = rk4(dt, t, fn, iv)
plt.plot(t, soln, lintype)
if y_min > np.min(soln):
y_min = np.min(soln)
if y_max < np.max(soln):
y_max = np.max(soln)
为使工作示例最少,请包括以下功能:
def rk4(dt, t, field, y_0):
"""
:param dt: float - the timestep
:param t: array - the time mesh
:param field: method - the vector field y' = f(t, y)
:param y_0: array - contains initial conditions
:return: ndarray - solution
"""
# Initialize solution matrix. Each row is the solution to the system
# for a given time step. Each column is the full solution for a single
# equation.
y = np.asarray(len(t) * [y_0])
for i in np.arange(len(t) - 1):
k1 = dt * field(t[i], y[i])
k2 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k1)
k3 = dt * field(t[i] + 0.5 * dt, y[i] + 0.5 * k2)
k4 = dt * field(t[i] + dt, y[i] + k3)
y[i + 1] = y[i] + (k1 + 2 * k2 + 2 * k3 + k4) / 6
return y
if __name__ == '__main__':
def f(t, x): return x**2 - x
vf_grapher(f, 0, 4, 0.1, (-0.9, 0.5, 1.01), xlab='t', ylab='x(t)',
sup_title=r'Solution Field for $\dot{x} = x^2 - x$')
答案 0 :(得分:2)
您可以使用ndmin
的{{1}}参数来确保数组实际上是可迭代的:
np.array
np.array(y_0, ndmin=1, copy=False)
只是np.asarray
的别名,它以不同的方式设置了一些默认参数。
np.array
可用于用单位尺寸填充形状。这有帮助,因为通常零维数组等效于标量。这样做的一个烦人的副作用是它们不是可迭代的。 ndmin
意味着应将标量输入视为一维,一元数组,这就是您要查找的内容。
ndmin=1
只是告诉numpy按原样使用现有数组,而不是进行复制。这样,如果用户传入一个实际的数组(而不是列表或标量),则将使用该数组而不会重复数据。我经常将此参数与copy=False
配对,该参数也将通过subok=True
的子类传递,而不进行复制。
答案 1 :(得分:0)
我不确定为什么可以遍历单个对象的列表而不是单个对象的数组,但是我在另一个问题的答案中找到了一种确定项目是否可迭代的方法: https://stackoverflow.com/a/1952481/3696204
然后,我尝试了一下,除了如下所示的块:
try:
iter(y_0)
except TypeError:
y_0 = list([y_0])
for iv in y_0:
soln = rk4(dt, t, fn, iv)
plt.plot(t, soln, lintype)
if y_min > np.min(soln):
y_min = np.min(soln)
if y_max < np.max(soln):
y_max = np.max(soln)
谢谢大家的有用评论。