得到“预期的类型参数,找到结构”,我在这里做错了什么?

时间:2019-05-06 23:55:20

标签: multidimensional-array rust

我正在使用Rust和ndarray通过PragProg的Programming Machine Learning进行工作。我有以下代码,并且在w中遇到了train()类型的编译错误,但是它看起来与predict()中的工作返回类型非常相似,我无法确定找出我的类型有什么问题。

我看过的其他问题均未解决我的特定问题。我尝试使用array![0.]Array::zeros(1),但都没有正确的类型。

#[macro_use]
extern crate ndarray;

use ndarray::prelude::*;

fn main() {
    println!("Hello, world!");
    let x: Array<f32, ndarray::Dim<[usize; 2]>> = array![[13., 2., 14., 23.]];
    let y: Array<f32, ndarray::Dim<[usize; 2]>> = array![[33., 16., 32., 51.]];

    let q: Array<f32, ndarray::Dim<[usize; 1]>> = train(&x, &y, 1000, 0.01);
    println!("{}", q);
}

fn predict<D, E>(x: &Array<f32, D>, w: &Array<f32, E>) -> Array<f32, D>
where
    D: Dimension,
    E: Dimension,
{
    return x * w;
}

fn loss<D, E>(x: &Array<f32, D>, y: &Array<f32, D>, w: &Array<f32, E>) -> f32
where
    D: Dimension,
    E: Dimension,
{
    let mut i = predict(x, w);
    i = i - y;
    i = i.mapv(|a| a.powi(2));
    return i.sum() / i.len() as f32;
}

fn train<D, E>(
    x: &Array<f32, D>,
    y: &Array<f32, D>,
    iterations: i32,
    learning_rate: f32,
) -> Array<f32, E>
where
    D: Dimension,
    E: Dimension,
{
    let w = array![0.];
    let mut current_loss;
    for i in 0..iterations {
        current_loss = loss(x, y, &w);
        println!("Iteration {} => Loss: {}", i, current_loss);

        if loss(x, y, &(&w + learning_rate)) < current_loss {
            w += learning_rate;
        } else if loss(x, y, &(&w - learning_rate)) < current_loss {
            w -= learning_rate;
        } else {
            break;
        }
    }
    println!("{}", w);
    return w;
}

由于predict可以编译,所以我希望train也可以编译,但是我得到了:

error[E0308]: mismatched types
  --> src/main.rs:59:12
   |
39 | ) -> Array<f32, E>
   |      ------------- expected `ndarray::ArrayBase<ndarray::OwnedRepr<f32>, E>` because of return type
...
59 |     return w;
   |            ^ expected type parameter, found struct `ndarray::Dim`
   |
   = note: expected type `ndarray::ArrayBase<_, E>`
              found type `ndarray::ArrayBase<_, ndarray::Dim<[usize; 1]>>`

0 个答案:

没有答案