Skip to content

Commit

Permalink
Test rewired deprecated API
Browse files Browse the repository at this point in the history
  • Loading branch information
relf committed Jan 27, 2025
1 parent 43b1829 commit 0192092
Showing 1 changed file with 28 additions and 31 deletions.
59 changes: 28 additions & 31 deletions algorithms/linfa-svm/src/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,16 @@ pub mod tests {
use linfa::dataset::Dataset;
use linfa::metrics::SingleTargetRegression;
use linfa::traits::{Fit, Predict};
use ndarray::Array;
use linfa::DatasetBase;
use ndarray::{Array, Array1, Array2};

fn _check_model(model: Svm<f64, f64>, dataset: &DatasetBase<Array2<f64>, Array1<f64>>) {
println!("{}", model);
let predicted = model.predict(dataset.records());
let err = predicted.mean_squared_error(&dataset).unwrap();
println!("err={}", err);
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);
}

#[test]
fn test_epsilon_regression_linear() -> Result<()> {
Expand All @@ -219,13 +228,15 @@ pub mod tests {
.c_svr(5., None)
.linear_kernel()
.fit(&dataset)?;
_check_model(model, &dataset);

println!("{}", model);

let predicted = model.predict(dataset.records());
let err = predicted.mean_squared_error(&dataset).unwrap();
println!("err={}", err);
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);
// Old API
#[allow(deprecated)]
let model2 = Svm::params()
.c_eps(5., 1e-3)
.linear_kernel()
.fit(&dataset)?;
_check_model(model2, &dataset);

Ok(())
}
Expand All @@ -242,14 +253,15 @@ pub mod tests {
.nu_svr(0.5, Some(1.))
.linear_kernel()
.fit(&dataset)?;
_check_model(model, &dataset);

println!("{}", model);

let predicted = model.predict(&dataset);
let err = predicted.mean_squared_error(&dataset).unwrap();
println!("err={}", err);
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);

// Old API
#[allow(deprecated)]
let model2 = Svm::params()
.nu_eps(0.5, 1e-3)
.linear_kernel()
.fit(&dataset)?;
_check_model(model2, &dataset);
Ok(())
}

Expand All @@ -259,22 +271,14 @@ pub mod tests {
.into_shape((100, 1))
.unwrap();
let sin_curve = records.mapv(|v| v.sin()).into_shape((100,)).unwrap();

let dataset = Dataset::new(records, sin_curve);

let model = Svm::params()
.c_svr(100., Some(0.1))
.gaussian_kernel(10.)
.eps(1e-3)
.fit(&dataset)?;

println!("{}", model);

let predicted = model.predict(&dataset);
let err = predicted.mean_squared_error(&dataset).unwrap();
println!("err={}", err);
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);

_check_model(model, &dataset);
Ok(())
}

Expand All @@ -290,14 +294,7 @@ pub mod tests {
.polynomial_kernel(1., 3.)
.eps(1e-3)
.fit(&dataset)?;

println!("{}", model);

let predicted = model.predict(&dataset);
let err = predicted.mean_squared_error(&dataset).unwrap();
println!("err={}", err);
assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2);

_check_model(model, &dataset);
Ok(())
}
}

0 comments on commit 0192092

Please sign in to comment.