From 01920925cb601baf5dae01eb99a3c873b490f52f Mon Sep 17 00:00:00 2001 From: relf Date: Mon, 27 Jan 2025 17:13:03 +0100 Subject: [PATCH] Test rewired deprecated API --- algorithms/linfa-svm/src/regression.rs | 59 ++++++++++++-------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/algorithms/linfa-svm/src/regression.rs b/algorithms/linfa-svm/src/regression.rs index 116ddc1cc..77aa50a6f 100644 --- a/algorithms/linfa-svm/src/regression.rs +++ b/algorithms/linfa-svm/src/regression.rs @@ -206,7 +206,16 @@ pub mod tests { use linfa::dataset::Dataset; use linfa::metrics::SingleTargetRegression; use linfa::traits::{Fit, Predict}; - use ndarray::Array; + use linfa::DatasetBase; + use ndarray::{Array, Array1, Array2}; + + fn _check_model(model: Svm, dataset: &DatasetBase, Array1>) { + println!("{}", model); + let predicted = model.predict(dataset.records()); + let err = predicted.mean_squared_error(&dataset).unwrap(); + println!("err={}", err); + assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2); + } #[test] fn test_epsilon_regression_linear() -> Result<()> { @@ -219,13 +228,15 @@ pub mod tests { .c_svr(5., None) .linear_kernel() .fit(&dataset)?; + _check_model(model, &dataset); - println!("{}", model); - - let predicted = model.predict(dataset.records()); - let err = predicted.mean_squared_error(&dataset).unwrap(); - println!("err={}", err); - assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2); + // Old API + #[allow(deprecated)] + let model2 = Svm::params() + .c_eps(5., 1e-3) + .linear_kernel() + .fit(&dataset)?; + _check_model(model2, &dataset); Ok(()) } @@ -242,14 +253,15 @@ pub mod tests { .nu_svr(0.5, Some(1.)) .linear_kernel() .fit(&dataset)?; + _check_model(model, &dataset); - println!("{}", model); - - let predicted = model.predict(&dataset); - let err = predicted.mean_squared_error(&dataset).unwrap(); - println!("err={}", err); - assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2); - + // Old API + #[allow(deprecated)] + let model2 = Svm::params() + .nu_eps(0.5, 1e-3) + .linear_kernel() + .fit(&dataset)?; + _check_model(model2, &dataset); Ok(()) } @@ -259,7 +271,6 @@ pub mod tests { .into_shape((100, 1)) .unwrap(); let sin_curve = records.mapv(|v| v.sin()).into_shape((100,)).unwrap(); - let dataset = Dataset::new(records, sin_curve); let model = Svm::params() @@ -267,14 +278,7 @@ pub mod tests { .gaussian_kernel(10.) .eps(1e-3) .fit(&dataset)?; - - println!("{}", model); - - let predicted = model.predict(&dataset); - let err = predicted.mean_squared_error(&dataset).unwrap(); - println!("err={}", err); - assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2); - + _check_model(model, &dataset); Ok(()) } @@ -290,14 +294,7 @@ pub mod tests { .polynomial_kernel(1., 3.) .eps(1e-3) .fit(&dataset)?; - - println!("{}", model); - - let predicted = model.predict(&dataset); - let err = predicted.mean_squared_error(&dataset).unwrap(); - println!("err={}", err); - assert!(predicted.mean_squared_error(&dataset).unwrap() < 1e-2); - + _check_model(model, &dataset); Ok(()) } }