我想在django中使用numpy数组字段,以便我可以做这样的事情
from example.models import Series
import numpy as np
array = np.array([1, 2, 3])
model = Series.objects.create(id=1, array=array)
model = Series.objects.get(id=1)
assert np.array_equal(array, model.array)
本质上,该字段应该将numpy数组序列化为二进制并自动反序列化。目前,我只是这样做:
import base64
import numpy as np
from django.db import models
class Series(models.Model):
id = models.IntegerField(primary_key=True, unique=True)
array = models.BinaryField()
def get_array():
return np.frombuffer(base64.decodebytes(self.array), dtype=np.float32)
def set_array(array):
self.array = base64.b64encode(array)
如果这是一个可重用的字段,我更喜欢它,因为我有许多需要存储numpy数组的模型。例如:
class Series(models.Model):
array = NumpyArrayField(dtype=np.float32)
那么,我怎样才能编写一个完成此任务的NumpyArrayField
类?
我尝试了以下操作(复制BinaryField的源代码)
import base64
import numpy as np
from django.db import models
class NumpyArrayField(models.Field):
empty_values = [None]
def __init__(self, dtype, *args, **kwargs):
self.dtype = dtype
super(NumpyArrayField, self).__init__(*args, **kwargs)
def deconstruct(self):
name, path, args, kwargs = super(NumpyArrayField, self).deconstruct()
kwargs['dtype'] = self.dtype
return name, path, args, kwargs
def get_internal_type(self):
return 'NumpyArrayField'
def get_placeholder(self, value, compiler, connection):
return connection.ops.binary_placeholder_sql(value)
def get_default(self):
if self.has_default() and not callable(self.default):
return self.default
default = super(NumpyArrayField, self).get_default()
if default == '':
return b''
return default
def get_db_prep_value(self, value, connection, prepared=False):
value = super(NumpyArrayField, self).get_db_prep_value(value, connection, prepared)
value = base64.b64encode(value)
if value is not None:
return connection.Database.Binary(value)
return value
def value_to_string(self, obj):
return base64.b64encode(obj).decode('ascii')
def to_python(self, value):
return np.frombuffer(base64.decodebytes(value), dtype=self.dtype)
class Series(models.Model):
id = models.IntegerField(primary_key=True, unique=True)
array = NumpyArrayField(dtype=np.int32)
迁移运行良好,但我收到django.db.utils.OperationalError: table example_series has no column named array
错误。
答案 0 :(得分:0)
我用MySQL将numpy数组保存到Django模型中,
从django.db导入模型
np_field = models.BinaryField()
np_bytes = pickle.dumps(np_array)
np_base64 = base64.b64encode(np_bytes)
model.np_field = np_base64
np_bytes = base64.b64decode(model.np_field)
np_array = pickle.loads(np_bytes)