将scipy对象保存到文件

时间:2016-01-15 13:17:24

标签: python numpy scipy interpolation

我想将interpolator生成的对象scipy.interpolate.InterpolatedUnivariateSpline保存到文件中,以便之后加载并使用它。 这是控制台上的结果:

>>> interpolator
 <scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C27170>
np.save("interpolator",np.array(interpolator))
>>> f = np.load("interpolator.npy")
>>> f
array(<scipy.interpolate.fitpack2.InterpolatedUnivariateSpline object at 0x11C08FB0>, dtype=object)

这些是尝试使用带有通用值的加载插值器f的结果:

>>>f(10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'numpy.ndarray' object is not callable

或:

>>> f[0](10)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: too many indices for array

如何正确保存/加载?

2 个答案:

答案 0 :(得分:5)

interpolator对象不是数组,因此np.save已将其包装在object数组中。并且它返回pickle以保存不是数组的元素。所以你得到一个带有一个对象的0d数组。

用简单的字典对象来说明:

In [280]: np.save('test.npy',{'one':1})
In [281]: x=np.load('test.npy')
In [282]: x
Out[282]: array({'one': 1}, dtype=object)
In [283]: x[0]
...
IndexError: 0-d arrays can't be indexed
In [284]: x[()]
Out[284]: {'one': 1}
In [285]: x.item()
Out[285]: {'one': 1}
In [288]: x.item()['one']
Out[288]: 1

因此,item[()]将从数组中检索此对象。然后,您应该像save之前那样使用它。

使用您自己的pickle电话很好。

答案 1 :(得分:2)

看起来像numpy.save然后numpy.load将scipy InterpolatedUnivariateSpline对象转换为numpy对象。 Numpy save / load apparently有一个allow_pickle=True输入,应该保留对象信息。这在我的numpy版本(1.9.2)中并没有出现,我想也许你的版本也是如此。使用spl=numpy.load("file")时,类型信息会丢失,因此作为方法的调用spl将失败。由于numpy save主要是为二进制数据数组而设计的,因此最通用的解决方案可能是使用pickle。作为最低限度的例子,

import matplotlib.pyplot as plt
from scipy.interpolate import InterpolatedUnivariateSpline
import numpy as np
try:
    import cPickle as pickle
except ImportError:
    import pickle

x = np.linspace(-3, 3, 50)
y = np.exp(-x**2) + 0.1 * np.random.randn(50)
spl = InterpolatedUnivariateSpline(x, y)
plt.plot(x, y, 'ro', ms=5)


xs = np.linspace(-3, 3, 1000)

#Plot before save
plt.plot(xs, spl(xs), 'g', lw=3, alpha=0.7)

#Save, load and plot again (NOTE CAUSES ERROR)
#np.save("interpolator",spl)
#spl_loaded = np.load("interpolator.npy")
#plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)

#Pickle, unpickle and then plot again
with open('interpolator.pkl', 'wb') as f:
    pickle.dump(spl, f)
with open('interpolator.pkl', 'rb') as f:
    spl_loaded = pickle.load(f)
plt.plot(xs, spl_loaded(xs), 'k--', lw=3, alpha=0.7)

plt.show()