如何在PyO3中实现python运算符

时间:2019-12-06 09:05:32

标签: python rust pyo3

我正在尝试为我的数学库在rust中实现向量类。

#[pyclass]
struct Vec2d {
    #[pyo3(get, set)]
    x: f64,
    #[pyo3(get, set)]
    y: f64
}

但是我不知道如何重载标准运算符(+,-,*,/)

我尝试实现std :: ops的Add特性,但是没有运气

impl Add for Vec2d {
    type Output = Vec2d;
    fn add(self, other: Vec2d) -> Vec2d {
        Vec2d{x: self.x + other.x, y: self.y + other.y }
    }
}

我还尝试将__add__方法添加到#[pymethods]块中

fn __add__(&self, other: & Vec2d) -> PyResult<Vec2d> {
    Ok(Vec2d{x: self.x + other.x, y: self.y + other.y })
}

但仍然无法正常工作。

使用第二种方法,我可以看到该方法存在,但是python不能将其识别为运算符重载

In [2]: v1 = Vec2d(3, 4)
In [3]: v2 = Vec2d(6, 7)
In [4]: v1 + v2
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-08104d7e1232> in <module>()
----> 1 v1 + v2

TypeError: unsupported operand type(s) for +: 'Vec2d' and 'Vec2d'

In [5]: v1.__add__(v2)
Out[5]: <Vec2d object at 0x0000026B74C2B6F0>

2 个答案:

答案 0 :(得分:3)

根据PyO3文档,

Python的对象模型为不同的对象行为定义了几种协议,例如序列,映射或数字协议。 PyO3为它们各自定义了不同的特征。 要提供特定的python对象行为,您需要为您的结构实现特定的特征。

重要说明,每个协议实现模块都必须使用#[{pyproto] 属性进行注释。

__add____sub__等在PyNumberProtocol特性中定义。

因此,您可以为自己的PyNumberProtocol结构实现Vec2d,以重载标准操作。

#[pyproto]
impl PyNumberProtocol for Vec2d {
    fn __add__(&self, other: & Vec2d) -> PyResult<Vec2d> {
            Ok(Vec2d{x: self.x + other.x, y: self.y + other.y })
   }
}

答案 1 :(得分:3)

我将添加此答案,以免其他人像我一样搜索小时。

使用@Abdul Niyas P M提供的答案,我遇到以下错误:

error: custom attribute panicked
  --> src/vec2.rs:49:1
   |
49 | #[pyproto]
   | ^^^^^^^^^^
   |
   = help: message: fn arg type is not supported

事实证明,此错误消息隐藏了两个问题。 第一个问题是__add__应该使用值而不是引用,因此我们在&self之前删除了Vec2。 这使我们摆脱了错误消息:

error[E0277]: the trait bound `&vec2::Vec2: pyo3::pyclass::PyClass` is not satisfied
   --> src/vec2.rs:47:1
    |
47  | #[pyproto]
    | ^^^^^^^^^^ the trait `pyo3::pyclass::PyClass` is not implemented for `&vec2::Vec2`
    | 

当我们指定self的类型时,第二个问题可以发现:

// DOES NOT COMPILE
#[pyproto]
impl PyNumberProtocol for Vec2 {
    fn __add__(self: Vec2, other: Vec2) -> PyResult<Vec2> {
        Ok(Vec2{x: self.x + other.x, y: self.y + other.y})
    }
}

这无法与错误消息一起编译

error[E0185]: method `__add__` has a `self: <external::vec3::Vec3 as pyo3::class::number::PyNumberAddProtocol<'p>>::Left` declaration in the impl, but not in the trait
  --> src/external/vec2.rs:49:5
   |
49 |     fn __add__(self: Vec2, other: Vec2) -> PyResult<Vec2> {
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ `self: <vec2::Vec2 as pyo3::class::number::PyNumberAddProtocol<'p>>::Left` used in impl
   |
   = note: `__add__` from trait: `fn(<Self as pyo3::class::number::PyNumberAddProtocol<'p>>::Left, <Self as pyo3::class::number::PyNumberAddProtocol<'p>>::Right) -> <Self as pyo3::class::number::PyNumberAddProtocol<'p>>::Result`

这导致我们找到最终的工作解决方案(截至2020年6月):

#[pyproto]
impl PyNumberProtocol for Vec2 {
    fn __add__(lhs: Vec2, rhs: Vec2) -> PyResult<Vec2> {
        Ok(Vec2{x: lhs.x + rhs.x, y: lhs.y + rhs.y})
    }
}

哪个可以在Rust每晚1.45下成功编译,并且已经过检查可以在Python上运行。

也可以使用其他类型的rhs

#[pyproto]
impl PyNumberProtocol for Vec2 {
    fn __mul__(lhs: Vec2, rhs: f64) -> PyResult<Vec2> {
        Ok(Vec3{x: lhs.x * rhs, y: lhs.y * rhs})
    }
}

还要注意,拥有self并不总是一个问题:

#[pyproto]
impl PyNumberProtocol for Vec2 {
    fn __neg__(self) -> PyResult<Vec2> {
        Ok(Vec3{x: -self.x, y: -self.y})
    }
}