我正在使用Rust和ndarray通过PragProg的Programming Machine Learning进行工作。我有以下代码,并且在w
中遇到了train()
类型的编译错误,但是它看起来与predict()
中的工作返回类型非常相似,我无法确定找出我的类型有什么问题。
我看过的其他问题均未解决我的特定问题。我尝试使用array![0.]
和Array::zeros(1)
,但都没有正确的类型。
#[macro_use]
extern crate ndarray;
use ndarray::prelude::*;
fn main() {
println!("Hello, world!");
let x: Array<f32, ndarray::Dim<[usize; 2]>> = array![[13., 2., 14., 23.]];
let y: Array<f32, ndarray::Dim<[usize; 2]>> = array![[33., 16., 32., 51.]];
let q: Array<f32, ndarray::Dim<[usize; 1]>> = train(&x, &y, 1000, 0.01);
println!("{}", q);
}
fn predict<D, E>(x: &Array<f32, D>, w: &Array<f32, E>) -> Array<f32, D>
where
D: Dimension,
E: Dimension,
{
return x * w;
}
fn loss<D, E>(x: &Array<f32, D>, y: &Array<f32, D>, w: &Array<f32, E>) -> f32
where
D: Dimension,
E: Dimension,
{
let mut i = predict(x, w);
i = i - y;
i = i.mapv(|a| a.powi(2));
return i.sum() / i.len() as f32;
}
fn train<D, E>(
x: &Array<f32, D>,
y: &Array<f32, D>,
iterations: i32,
learning_rate: f32,
) -> Array<f32, E>
where
D: Dimension,
E: Dimension,
{
let w = array![0.];
let mut current_loss;
for i in 0..iterations {
current_loss = loss(x, y, &w);
println!("Iteration {} => Loss: {}", i, current_loss);
if loss(x, y, &(&w + learning_rate)) < current_loss {
w += learning_rate;
} else if loss(x, y, &(&w - learning_rate)) < current_loss {
w -= learning_rate;
} else {
break;
}
}
println!("{}", w);
return w;
}
由于predict
可以编译,所以我希望train
也可以编译,但是我得到了:
error[E0308]: mismatched types
--> src/main.rs:59:12
|
39 | ) -> Array<f32, E>
| ------------- expected `ndarray::ArrayBase<ndarray::OwnedRepr<f32>, E>` because of return type
...
59 | return w;
| ^ expected type parameter, found struct `ndarray::Dim`
|
= note: expected type `ndarray::ArrayBase<_, E>`
found type `ndarray::ArrayBase<_, ndarray::Dim<[usize; 1]>>`