From 733192ba77f129d6a730d237bf1d29a9710cb04c Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 25 Mar 2024 00:08:02 +0100 Subject: [PATCH 01/28] Define a Vector trait and use it in `fit` and `cblas::level1` Added benefits to the additional generality: - all functions are bound checked; - the complex numbers are handled correctly and are compatible with their standard Rust representation. --- Cargo.toml | 4 + examples/fitting.rs | 17 +- src/cblas.rs | 740 +++++++++++++++++++++++++++++++++----------- src/fit.rs | 152 ++++----- src/types/vector.rs | 82 +++++ 5 files changed, 718 insertions(+), 277 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fcee07cd..fdfef2f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,8 +13,10 @@ edition = "2021" [dependencies] sys = { path = "gsl-sys", package = "GSL-sys", version = "3.0.0" } paste = "1.0" +num-complex = { version = "0.4.5", optional = true } [features] +default = ["complex"] v2_1 = ["sys/v2_1"] v2_2 = ["sys/v2_2", "v2_1"] v2_3 = ["sys/v2_3", "v2_2"] @@ -23,6 +25,8 @@ v2_5 = ["sys/v2_5", "v2_4"] v2_6 = ["sys/v2_6", "v2_5"] v2_7 = ["sys/v2_7", "v2_6"] dox = ["v2_7", "sys/dox"] +# Enable complex number functions: +complex = ["dep:num-complex"] [package.metadata.docs.rs] features = ["dox"] diff --git a/examples/fitting.rs b/examples/fitting.rs index 13f5564f..3e83e1ff 100644 --- a/examples/fitting.rs +++ b/examples/fitting.rs @@ -6,28 +6,27 @@ extern crate rgsl; use rgsl::fit; -const N: usize = 4; - fn main() { - let x = &[1970., 1980., 1990., 2000.]; - let y = &[12., 11., 14., 13.]; - let w = &[0.1, 0.2, 0.3, 0.4]; + let x = [1970., 1980., 1990., 2000.]; + let y = [12., 11., 14., 13.]; + let w = [0.1, 0.2, 0.3, 0.4]; - let (c0, c1, cov00, cov01, cov11, chisq) = fit::wlinear(x, 1, w, 1, y, 1, N).unwrap(); + let (c0, c1, cov00, cov01, cov11, chisq) = fit::wlinear(&x, &w, &y).unwrap(); println!("# best fit: Y = {} + {} X", c0, c1); println!("# covariance matrix:"); println!("# [ {}, {}\n# {}, {}]", cov00, cov01, cov01, cov11); println!("# chisq = {}", chisq); - for i in 0..N { - println!("data: {} {} {}", x[i], y[i], 1. / w[i].sqrt()); + for ((x, y), w) in x.iter().zip(y).zip(w) { + println!("data: {} {} {}", x, y, 1. / w.sqrt()); } println!(); + let dx = (x[x.len() - 1] - x[0]) / 100.; for i in -30..130 { - let xf = x[0] + (i as f64 / 100.) * (x[N - 1] - x[0]); + let xf = x[0] + i as f64 * dx; let (yf, yf_err) = fit::linear_est(xf, c0, c1, cov00, cov01, cov11).unwrap(); println!("fit: {} {}", xf, yf); diff --git a/src/cblas.rs b/src/cblas.rs index b81f1cab..da7d15fc 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -2,382 +2,754 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // +use crate::vector::Vector; + +/// Return the length of `x` as a `i32` value (to use in CBLAS calls). +#[inline] +fn len>(x: &T) -> i32 { + x.len().try_into().expect("Length must fit in `i32`") +} + +#[inline] +fn as_ptr>(x: &T) -> *const F { + x.as_slice().as_ptr() +} + +#[inline] +fn as_mut_ptr>(x: &mut T) -> *mut F { + x.as_mut_slice().as_mut_ptr() +} + +/// Return the stride of `x` as a `i32` value (to use in CBLAS calls). +#[inline] +fn stride>(x: &T) -> i32 { + x.stride().try_into().expect("Stride must fit in `i32`") +} + pub mod level1 { + #[cfg(feature = "complex")] + use num_complex::Complex; + use super::{as_mut_ptr, as_ptr, len, stride}; + use crate::vector::{check_equal_len, Vector}; + + /// Return the sum of `alpha` and the dot product of `x` and `y`. #[doc(alias = "cblas_sdsdot")] - pub fn sdsdot(N: i32, alpha: f32, x: &[f32], incx: i32, y: &[f32], incy: i32) -> f32 { - unsafe { sys::cblas_sdsdot(N, alpha, x.as_ptr(), incx, y.as_ptr(), incy) } + pub fn sdsdot>(alpha: f32, x: &T, y: &T) -> f32 { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + unsafe { sys::cblas_sdsdot(len(x), alpha, as_ptr(x), stride(x), as_ptr(y), stride(y)) } } + /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_dsdot")] - pub fn dsdot(N: i32, x: &[f32], incx: i32, y: &[f32], incy: i32) -> f64 { - unsafe { sys::cblas_dsdot(N, x.as_ptr(), incx, y.as_ptr(), incy) } + pub fn dsdot>(x: &T, y: &T) -> f64 { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + unsafe { sys::cblas_dsdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } + /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_sdot")] - pub fn sdot(N: i32, x: &[f32], incx: i32, y: &[f32], incy: i32) -> f32 { - unsafe { sys::cblas_sdot(N, x.as_ptr(), incx, y.as_ptr(), incy) } + pub fn sdot>(x: &T, y: &T) -> f32 { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + unsafe { sys::cblas_sdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } + /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_ddot")] - pub fn ddot(N: i32, x: &[f64], incx: i32, y: &[f64], incy: i32) -> f64 { - unsafe { sys::cblas_ddot(N, x.as_ptr(), incx, y.as_ptr(), incy) } + pub fn ddot>(x: &T, y: &T) -> f64 { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + unsafe { sys::cblas_ddot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } + #[cfg(feature = "complex")] + /// Return the unconjugated dot product between `x` and `y`, that + /// is ∑ xᵢ yᵢ. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::cdotu; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(cdotu(&x, &x), Complex::new(3., 6.)) + /// ``` #[doc(alias = "cblas_cdotu_sub")] - pub fn cdotu_sub(N: i32, x: &[T], incx: i32, y: &[T], incy: i32, dotu: &mut [T]) { + pub fn cdotu>>(x: &T, y: &T) -> Complex { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + let mut dotu: Complex = Complex::new(0., 0.); unsafe { sys::cblas_cdotu_sub( - N, - x.as_ptr() as *const _, - incx, - y.as_ptr() as *const _, - incy, - dotu.as_mut_ptr() as *mut _, + len(x), + as_ptr(x) as *const _, + stride(x), + as_ptr(y) as *const _, + stride(y), + &mut dotu as *mut Complex as *mut _, ) } + dotu } + #[cfg(feature = "complex")] + /// Return the (conjugated) dot product between `x` and `y`, that + /// is ∑ x̅ᵢ yᵢ. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::cdotc; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(cdotc(&x, &x), Complex::new(7., 0.)) + /// ``` #[doc(alias = "cblas_cdotc_sub")] - pub fn cdotc_sub(N: i32, x: &[T], incx: i32, y: &[T], incy: i32, dotc: &mut [T]) { + pub fn cdotc>>(x: &T, y: &T) -> Complex { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + let mut dotc: Complex = Complex::new(0., 0.); unsafe { sys::cblas_cdotc_sub( - N, - x.as_ptr() as *const _, - incx, - y.as_ptr() as *const _, - incy, - dotc.as_mut_ptr() as *mut _, + len(x), + as_ptr(x) as *const _, + stride(x), + as_ptr(y) as *const _, + stride(y), + &mut dotc as *mut Complex as *mut _, ) } + dotc } + #[cfg(feature = "complex")] + /// Return the unconjugated dot product between `x` and `y`, that + /// is ∑ xᵢ yᵢ. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::zdotu; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(zdotu(&x, &x), Complex::new(3., 6.)) + /// ``` #[doc(alias = "cblas_zdotu_sub")] - pub fn zdotu_sub(N: i32, x: &[T], incx: i32, y: &[T], incy: i32, dotu: &mut [T]) { + pub fn zdotu>>(x: &T, y: &T) -> Complex { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + let mut dotu: Complex = Complex::new(0., 0.); unsafe { sys::cblas_zdotu_sub( - N, - x.as_ptr() as *const _, - incx, - y.as_ptr() as *const _, - incy, - dotu.as_mut_ptr() as *mut _, + len(x), + as_ptr(x) as *const _, + stride(x), + as_ptr(y) as *const _, + stride(y), + &mut dotu as *mut Complex as *mut _, ) } + dotu } + #[cfg(feature = "complex")] + /// Return the (conjugated) dot product between `x` and `y`, that + /// is ∑ x̅ᵢ yᵢ. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::zdotc; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(zdotc(&x, &x), Complex::new(7., 0.)) + /// ``` #[doc(alias = "cblas_zdotc_sub")] - pub fn zdotc_sub(N: i32, x: &[T], incx: i32, y: &[T], incy: i32, dotc: &mut [T]) { + pub fn zdotc>>(x: &T, y: &T) -> Complex { + check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); + let mut dotc: Complex = Complex::new(0., 0.); unsafe { sys::cblas_zdotc_sub( - N, - x.as_ptr() as *const _, - incx, - y.as_ptr() as *const _, - incy, - dotc.as_mut_ptr() as *mut _, + len(x), + as_ptr(x) as *const _, + stride(x), + as_ptr(y) as *const _, + stride(y), + &mut dotc as *mut Complex as *mut _, ) } + dotc } + /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_snrm2")] - pub fn snrm2(N: i32, x: &[f32], incx: i32) -> f32 { - unsafe { sys::cblas_snrm2(N, x.as_ptr(), incx) } + pub fn snrm2>(x: &T) -> f32 { + unsafe { sys::cblas_snrm2(len(x), as_ptr(x), stride(x)) } } + /// Return the sum of the absolute values of the elements of `x` + /// (i.e., its L¹-norm). #[doc(alias = "cblas_sasum")] - pub fn sasum(N: i32, x: &[f32], incx: i32) -> f32 { - unsafe { sys::cblas_sasum(N, x.as_ptr(), incx) } + pub fn sasum>(x: &T) -> f32 { + unsafe { sys::cblas_sasum(len(x), as_ptr(x), stride(x)) } } + /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_dnrm2")] - pub fn dnrm2(N: i32, x: &[f64], incx: i32) -> f64 { - unsafe { sys::cblas_dnrm2(N, x.as_ptr(), incx) } + pub fn dnrm2>(x: &T) -> f64 { + unsafe { sys::cblas_dnrm2(len(x), as_ptr(x), stride(x)) } } + /// Return the sum of the absolute values of the elements of `x` + /// (i.e., its L¹-norm). #[doc(alias = "cblas_dasum")] - pub fn dasum(N: i32, x: &[f64], incx: i32) -> f64 { - unsafe { sys::cblas_dasum(N, x.as_ptr(), incx) } + pub fn dasum>(x: &T) -> f64 { + unsafe { sys::cblas_dasum(len(x), as_ptr(x), stride(x)) } } + #[cfg(feature = "complex")] + /// Return the Euclidean norm of `x`. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::scnrm2; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(scnrm2(&x), 7f32.sqrt()) + /// ``` #[doc(alias = "cblas_scnrm2")] - pub fn scnrm2(N: i32, x: &[T], incx: i32) -> f32 { - unsafe { sys::cblas_scnrm2(N, x.as_ptr() as *const _, incx) } + pub fn scnrm2>>(x: &T) -> f32 { + unsafe { sys::cblas_scnrm2(len(x), as_ptr(x) as *const _, stride(x)) } } + #[cfg(feature = "complex")] + /// Return the sum of the modulus of the elements of `x` + /// (i.e., its L¹-norm). #[doc(alias = "cblas_scasum")] - pub fn scasum(N: i32, x: &[T], incx: i32) -> f32 { - unsafe { sys::cblas_scasum(N, x.as_ptr() as *const _, incx) } + pub fn scasum>>(x: &T) -> f32 { + unsafe { sys::cblas_scasum(len(x), as_ptr(x) as *const _, stride(x)) } } + #[cfg(feature = "complex")] + /// Return the Euclidean norm of `x`. + /// + /// # Example + /// + /// ``` + /// use num_complex::Complex; + /// use rgsl::cblas::level1::dznrm2; + /// let x = [Complex::new(1., 1.), Complex::new(2., 1.)]; + /// assert_eq!(dznrm2(&x), 7f64.sqrt()) + /// ``` #[doc(alias = "cblas_dznrm2")] - pub fn dznrm2(N: i32, x: &[T], incx: i32) -> f64 { - unsafe { sys::cblas_dznrm2(N, x.as_ptr() as *const _, incx) } + pub fn dznrm2>>(x: &T) -> f64 { + unsafe { sys::cblas_dznrm2(len(x), as_ptr(x) as *const _, stride(x)) } } + #[cfg(feature = "complex")] + /// Return the sum of the modulus of the elements of `x` + /// (i.e., its L¹-norm). #[doc(alias = "cblas_dzasum")] - pub fn dzasum(N: i32, x: &[T], incx: i32) -> f64 { - unsafe { sys::cblas_dzasum(N, x.as_ptr() as *const _, incx) } + pub fn dzasum>>(x: &T) -> f64 { + unsafe { sys::cblas_dzasum(len(x), as_ptr(x) as *const _, stride(x)) } } + /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_isamax")] - pub fn isamax(N: i32, x: &[f32], incx: i32) -> usize { - unsafe { sys::cblas_isamax(N, x.as_ptr(), incx) } + pub fn isamax>(x: &T) -> usize { + unsafe { sys::cblas_isamax(len(x), as_ptr(x), stride(x)) } } + /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_idamax")] - pub fn idamax(N: i32, x: &[f64], incx: i32) -> usize { - unsafe { sys::cblas_idamax(N, x.as_ptr(), incx) } + pub fn idamax>(x: &T) -> usize { + unsafe { sys::cblas_idamax(len(x), as_ptr(x), stride(x)) } } + #[cfg(feature = "complex")] + /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_icamax")] - pub fn icamax(N: i32, x: &[T], incx: i32) -> usize { - unsafe { sys::cblas_icamax(N, x.as_ptr() as *const _, incx) } + pub fn icamax>>(x: &T) -> usize { + unsafe { sys::cblas_icamax(len(x), as_ptr(x) as *const _, stride(x)) } } + #[cfg(feature = "complex")] + /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_izamax")] - pub fn izamax(N: i32, x: &[T], incx: i32) -> usize { - unsafe { sys::cblas_izamax(N, x.as_ptr() as *const _, incx) } + pub fn izamax>>(x: &T) -> usize { + unsafe { sys::cblas_izamax(len(x), as_ptr(x) as *const _, stride(x)) } } + /// Swap vectors `x` and `y`. #[doc(alias = "cblas_sswap")] - pub fn sswap(N: i32, x: &mut [f32], incx: i32, y: &mut [f32], incy: i32) { - unsafe { sys::cblas_sswap(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn sswap>(x: &mut T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } + /// Copy the content of `x` into `y`. #[doc(alias = "cblas_scopy")] - pub fn scopy(N: i32, x: &[f32], incx: i32, y: &mut [f32], incy: i32) { - unsafe { sys::cblas_scopy(N, x.as_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn scopy>(x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } + /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_saxpy")] - pub fn saxpy(N: i32, alpha: f32, x: &[f32], incx: i32, y: &mut [f32], incy: i32) { - unsafe { sys::cblas_saxpy(N, alpha, x.as_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn saxpy>(alpha: f32, x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { + sys::cblas_saxpy( + len(x), + alpha, + as_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + ) + } } + /// Swap vectors `x` and `y`. #[doc(alias = "cblas_dswap")] - pub fn dswap(N: i32, x: &mut [f64], incx: i32, y: &mut [f64], incy: i32) { - unsafe { sys::cblas_dswap(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn dswap>(x: &mut T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } + /// Copy the content of `x` into `y`. #[doc(alias = "cblas_dcopy")] - pub fn dcopy(N: i32, x: &[f64], incx: i32, y: &mut [f64], incy: i32) { - unsafe { sys::cblas_dcopy(N, x.as_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn dcopy>(x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } + /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_daxpy")] - pub fn daxpy(N: i32, alpha: f64, x: &[f64], incx: i32, y: &mut [f64], incy: i32) { - unsafe { sys::cblas_daxpy(N, alpha, x.as_ptr(), incx, y.as_mut_ptr(), incy) } + pub fn daxpy>(alpha: f64, x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { + sys::cblas_daxpy( + len(x), + alpha, + as_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + ) + } } + #[cfg(feature = "complex")] + /// Swap vectors `x` and `y`. #[doc(alias = "cblas_cswap")] - pub fn cswap(N: i32, x: &mut [T], incx: i32, y: &mut [T], incy: i32) { + pub fn cswap>>(x: &mut T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_cswap( - N, - x.as_mut_ptr() as *mut _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + as_mut_ptr(x) as *mut _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + #[cfg(feature = "complex")] + /// Copy the content of `x` into `y`. #[doc(alias = "cblas_ccopy")] - pub fn ccopy(N: i32, x: &[T], incx: i32, y: &mut [T], incy: i32) { + pub fn ccopy>>(x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_ccopy( - N, - x.as_ptr() as *const _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + as_ptr(x) as *const _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + #[cfg(feature = "complex")] + /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_caxpy")] - pub fn caxpy(N: i32, alpha: &[T], x: &[T], incx: i32, y: &mut [T], incy: i32) { + pub fn caxpy>>(alpha: &Complex, x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_caxpy( - N, - alpha.as_ptr() as *const _, - x.as_ptr() as *const _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + alpha as *const Complex as *const _, + as_ptr(x) as *const _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + #[cfg(feature = "complex")] + /// Swap vectors `x` and `y`. #[doc(alias = "cblas_zswap")] - pub fn zswap(N: i32, x: &mut [T], incx: i32, y: &mut [T], incy: i32) { + pub fn zswap>>(x: &mut T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zswap( - N, - x.as_mut_ptr() as *mut _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + as_mut_ptr(x) as *mut _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + #[cfg(feature = "complex")] + /// Copy the content of `x` into `y`. #[doc(alias = "cblas_zcopy")] - pub fn zcopy(N: i32, x: &[T], incx: i32, y: &mut [T], incy: i32) { + pub fn zcopy>>(x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zcopy( - N, - x.as_ptr() as *const _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + as_ptr(x) as *const _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + #[cfg(feature = "complex")] + /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_zaxpy")] - pub fn zaxpy(N: i32, alpha: &[T], x: &[T], incx: i32, y: &mut [T], incy: i32) { + pub fn zaxpy>>(alpha: &Complex, x: &T, y: &mut T) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zaxpy( - N, - alpha.as_ptr() as *const _, - x.as_ptr() as *const _, - incx, - y.as_mut_ptr() as *mut _, - incy, + len(x), + alpha as *const Complex as *const _, + as_ptr(x) as *const _, + stride(x), + as_mut_ptr(y) as *mut _, + stride(y), ) } } + /// Given the Cartesian coordinates (`a`, `b`), returns + /// (c, s, r, z) such that + /// + /// ⎧c s⎫ ⎧a⎫ = ⎧r⎫ + /// ⎩s c⎭ ⎩b⎭ ⎩0⎭ + /// + /// The value of z is defined such that if |`a`| > |`b`|, z is s; + /// otherwise if c ≠ 0, z is 1/c; otherwise z is 1. #[doc(alias = "cblas_srotg")] - pub fn srotg(a: &mut [f32], b: &mut [f32], c: &mut [f32], s: &mut [f32]) { + pub fn srotg(a: f32, b: f32) -> (f32, f32, f32, f32) { + let mut c = f32::NAN; + let mut s = f32::NAN; + let mut r = a; + let mut z = b; unsafe { sys::cblas_srotg( - a.as_mut_ptr(), - b.as_mut_ptr(), - c.as_mut_ptr(), - s.as_mut_ptr(), - ) - } - } - + &mut r as *mut _, + &mut z as *mut _, + &mut c as *mut _, + &mut s as *mut _, + ) + } + (c, s, r, z) + } + + /// Modified matrix transformation (for the mathematical field `F`). + #[derive(Clone, Copy)] + pub enum H { + /// Specify that H is the matrix + /// + /// ⎧`h11` `h12`⎫ + /// ⎩`h21` `h22`⎭ + Full { + h11: F, + h21: F, + h12: F, + h22: F, + }, + /// Specify that H is the matrix + /// + /// ⎧1.0 `h12`⎫ + /// ⎩`h21` 1.0⎭ + OffDiag { + h21: F, + h12: F, + }, + /// Specify that H is the matrix + /// + /// ⎧`h11` 1.0⎫ + /// ⎩-1.0 `h22`⎭ + Diag { + h11: F, + h22: F, + }, + Id, + } + + /// Given Cartesian coordinates (`x1`, `x2`), return the + /// transformation matrix H that zeros the second component or the + /// vector (`x1` √`d1`, `x2` √`d2`): + /// + /// H ⎧`x1` √`d1`⎫ = ⎧y1⎫ + /// ⎩`x2` √`d2`⎭ ⎩0.⎭ + /// + /// The second component of the return value is `y1`. #[doc(alias = "cblas_srotmg")] - pub fn srotmg(d1: &mut [f32], d2: &mut [f32], b1: &mut [f32], b2: f32, P: &mut [f32]) { + pub fn srotmg(mut d1: f32, mut d2: f32, mut x1: f32, x2: f32) -> (H, f32) { + let mut h: [f32; 5] = [0.; 5]; unsafe { sys::cblas_srotmg( - d1.as_mut_ptr(), - d2.as_mut_ptr(), - b1.as_mut_ptr(), - b2, - P.as_mut_ptr(), + &mut d1 as *mut _, + &mut d2 as *mut _, + &mut x1 as *mut _, + x2, + &mut h as *mut _, + ) + } + let h = match h[0] { + -1.0 => H::Full { + h11: h[1], + h21: h[2], + h12: h[3], + h22: h[4], + }, + 0.0 => H::OffDiag { + h21: h[2], + h12: h[3], + }, + 1.0 => H::Diag { + h11: h[1], + h22: h[4], + }, + -2.0 => H::Id, + _ => unreachable!("srotmg: incorrect flag value"), + }; + (h, x1) + } + + /// Apply plane rotation. More specifically, perform the + /// following transformation in place : + /// + /// ⎧`x`ᵢ⎫ = ⎧`c` `s`⎫ ⎧`x`ᵢ⎫ + /// ⎩`y`ᵢ⎭ ⎩-`s` `c`⎭ ⎩`y`ᵢ⎭ + /// + /// for all indices i. + #[doc(alias = "cblas_srot")] + pub fn srot>(x: &mut T, y: &mut T, c: f32, s: f32) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { + sys::cblas_srot( + len(x), + as_mut_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + c, + s, ) } } - #[doc(alias = "cblas_srot")] - pub fn srot(N: i32, x: &mut [f32], incx: i32, y: &mut [f32], incy: i32, c: f32, s: f32) { - unsafe { sys::cblas_srot(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy, c, s) } - } - + /// Apply the matrix rotation `h` to `x`, `y`. + /// + /// ⎧`x`ᵢ⎫ = `h` ⎧`x`ᵢ⎫ + /// ⎩`y`ᵢ⎭ ⎩`y`ᵢ⎭ + /// + /// for all indices i. #[doc(alias = "cblas_srotm")] - pub fn srotm(N: i32, x: &mut [f32], incx: i32, y: &mut [f32], incy: i32, p: &[f32]) { - unsafe { sys::cblas_srotm(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy, p.as_ptr()) } + pub fn srotm>(x: &mut T, y: &mut T, h: H) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + let p = match h { + H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], + H::OffDiag { h21, h12 } => [0.0, 1., h21, h12, 1.], + H::Diag { h11, h22 } => [1.0, h11, -1., 1., h22], + H::Id => [-2.0, 1., 0., 0., 1.], + }; + unsafe { + sys::cblas_srotm( + len(x), + as_mut_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + &p as *const _, + ) + } } + /// Given the Cartesian coordinates (`a`, `b`), returns + /// (c, s, r, z) such that + /// + /// ⎧c s⎫ ⎧a⎫ = ⎧r⎫ + /// ⎩s c⎭ ⎩b⎭ ⎩0⎭ + /// + /// The value of z is defined such that if |`a`| > |`b`|, z is s; + /// otherwise if c ≠ 0, z is 1/c; otherwise z is 1. #[doc(alias = "cblas_drotg")] - pub fn drotg(a: &mut [f64], b: &mut [f64], c: &mut [f64], s: &mut [f64]) { + pub fn drotg(a: f64, b: f64) -> (f64, f64, f64, f64) { + let mut c = f64::NAN; + let mut s = f64::NAN; + let mut r = a; + let mut z = b; unsafe { sys::cblas_drotg( - a.as_mut_ptr(), - b.as_mut_ptr(), - c.as_mut_ptr(), - s.as_mut_ptr(), + &mut r as *mut _, + &mut z as *mut _, + &mut c as *mut _, + &mut s as *mut _, ) } + (c, s, r, z) } + /// Given Cartesian coordinates (`x1`, `x2`), return the + /// transformation matrix H that zeros the second component or the + /// vector (`x1` √`d1`, `x2` √`d2`): + /// + /// H ⎧`x1` √`d1`⎫ = ⎧y1⎫ + /// ⎩`x2` √`d2`⎭ ⎩0.⎭ + /// + /// The second component of the return value is `y1`. #[doc(alias = "cblas_drotmg")] - pub fn drotmg(d1: &mut [f64], d2: &mut [f64], b1: &mut [f64], b2: f64, P: &mut [f64]) { + pub fn drotmg(mut d1: f64, mut d2: f64, mut x1: f64, x2: f64) -> (H, f64) { + let mut h: [f64; 5] = [0.; 5]; unsafe { sys::cblas_drotmg( - d1.as_mut_ptr(), - d2.as_mut_ptr(), - b1.as_mut_ptr(), - b2, - P.as_mut_ptr(), + &mut d1 as *mut _, + &mut d2 as *mut _, + &mut x1 as *mut _, + x2, + &mut h as *mut _, + ) + } + let h = match h[0] { + -1.0 => H::Full { + h11: h[1], + h21: h[2], + h12: h[3], + h22: h[4], + }, + 0.0 => H::OffDiag { + h21: h[2], + h12: h[3], + }, + 1.0 => H::Diag { + h11: h[1], + h22: h[4], + }, + -2.0 => H::Id, + _ => unreachable!("srotmg: incorrect flag value"), + }; + (h, x1) + } + + /// Apply plane rotation. More specifically, perform the + /// following transformation in place : + /// + /// ⎧`x`ᵢ⎫ = ⎧`c` `s`⎫ ⎧`x`ᵢ⎫ + /// ⎩`y`ᵢ⎭ ⎩-`s` `c`⎭ ⎩`y`ᵢ⎭ + /// + /// for all indices i. + #[doc(alias = "cblas_drot")] + pub fn drot>(x: &mut T, y: &mut T, c: f64, s: f64) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + unsafe { + sys::cblas_drot( + len(x), + as_mut_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + c, + s, ) } } - #[doc(alias = "cblas_drot")] - pub fn drot(N: i32, x: &mut [f64], incx: i32, y: &mut [f64], incy: i32, c: f64, s: f64) { - unsafe { sys::cblas_drot(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy, c, s) } - } - + /// Apply the matrix rotation `h` to `x`, `y`. + /// + /// ⎧`x`ᵢ⎫ = `h` ⎧`x`ᵢ⎫ + /// ⎩`y`ᵢ⎭ ⎩`y`ᵢ⎭ + /// + /// for all indices i. #[doc(alias = "cblas_drotm")] - pub fn drotm(N: i32, x: &mut [f64], incx: i32, y: &mut [f64], incy: i32, p: &[f64]) { - unsafe { sys::cblas_drotm(N, x.as_mut_ptr(), incx, y.as_mut_ptr(), incy, p.as_ptr()) } + pub fn drotm>(x: &mut T, y: &mut T, h: H) { + check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); + let p = match h { + H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], + H::OffDiag { h21, h12 } => [0.0, 1., h21, h12, 1.], + H::Diag { h11, h22 } => [1.0, h11, -1., 1., h22], + H::Id => [-2.0, 1., 0., 0., 1.], + }; + unsafe { + sys::cblas_drotm( + len(x), + as_mut_ptr(x), + stride(x), + as_mut_ptr(y), + stride(y), + &p as *const _, + ) + } } - /// Multiple each element of a matrix/vector by a constant. - /// - /// __Postcondition__: Every incX'th element of X has been multiplied by a factor of alpha - /// - /// __Parameters__: - /// - /// * N : number of elements in x to scale - /// * alpha : factor to scale by - /// * X : pointer to the vector/matrix data - /// * incx : Amount to increment counter after each scaling, ie incX=2 mean to scale elements {1,3,...} - /// - /// Note that the allocated length of X must be incX*N-1 as N indicates the number of scaling operations to perform. + /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_sscal")] - pub fn sscal(N: i32, alpha: f32, x: &mut [f32], incx: i32) { - unsafe { sys::cblas_sscal(N, alpha, x.as_mut_ptr(), incx) } + pub fn sscal>(alpha: f32, x: &mut T) { + unsafe { sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } - /// Multiple each element of a matrix/vector by a constant. + /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_dscal")] - pub fn dscal(N: i32, alpha: f64, x: &mut [f64], incx: i32) { - unsafe { sys::cblas_dscal(N, alpha, x.as_mut_ptr(), incx) } + pub fn dscal>(alpha: f64, x: &mut T) { + unsafe { sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } - /// Multiple each element of a matrix/vector by a constant. + #[cfg(feature = "complex")] + /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_cscal")] - pub fn cscal(N: i32, alpha: &[T], x: &mut [T], incx: i32) { + pub fn cscal>>(alpha: &Complex, x: &mut T) { unsafe { sys::cblas_cscal( - N, - alpha.as_ptr() as *const _, - x.as_mut_ptr() as *mut _, - incx, + len(x), + alpha as *const Complex as *const _, + as_mut_ptr(x) as *mut _, + stride(x), ) } } - /// Multiple each element of a matrix/vector by a constant. + #[cfg(feature = "complex")] + /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_zscal")] - pub fn zscal(N: i32, alpha: &[T], x: &mut [T], incx: i32) { + pub fn zscal>>(alpha: &Complex, x: &mut T) { unsafe { sys::cblas_zscal( - N, - alpha.as_ptr() as *const _, - x.as_mut_ptr() as *mut _, - incx, + len(x), + alpha as *const Complex as *const _, + as_mut_ptr(x) as *mut _, + stride(x), ) } } - /// Multiple each element of a matrix/vector by a constant. + #[cfg(feature = "complex")] + /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_csscal")] - pub fn csscal(N: i32, alpha: f32, x: &mut [T], incx: i32) { - unsafe { sys::cblas_csscal(N, alpha, x.as_mut_ptr() as *mut _, incx) } + pub fn csscal>>(alpha: f32, x: &mut T) { + unsafe { sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } + #[cfg(feature = "complex")] /// Multiple each element of a matrix/vector by a constant. #[doc(alias = "cblas_zdscal")] - pub fn zdscal(N: i32, alpha: f64, x: &mut [T], incx: i32) { - unsafe { sys::cblas_zdscal(N, alpha, x.as_mut_ptr() as *mut _, incx) } + pub fn zdscal>>(alpha: f64, x: &mut T) { + unsafe { sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } } diff --git a/src/fit.rs b/src/fit.rs index 2ee19c85..58c6d321 100644 --- a/src/fit.rs +++ b/src/fit.rs @@ -8,33 +8,40 @@ Linear Regression The functions described in this section can be used to perform least-squares fits to a straight line model, Y(c,x) = c_0 + c_1 x. !*/ -use crate::Value; +use crate::{ + vector::{check_equal_len, Vector}, + Value, +}; -/// This function computes the best-fit linear regression coefficients (c0,c1) of the model -/// Y = c_0 + c_1 X for the dataset (x, y), two vectors of length n with strides xstride and -/// ystride. +/// This function computes the best-fit linear regression coefficients +/// (c0, c1) of the model Y = c_0 + c_1 X for the dataset (`x`, `y`), +/// two vectors of the same length (possibly with strides). /// -/// The errors on y are assumed unknown so the variance-covariance matrix for the parameters -/// (c0, c1) is estimated from the scatter of the points around the best-fit line and returned via -/// the parameters (cov00, cov01, cov11). +/// The errors on `y` are assumed unknown so the variance-covariance +/// matrix for the parameters (c0, c1) is estimated from the scatter +/// of the points around the best-fit line and returned via the +/// parameters (cov00, cov01, cov11). /// -/// The sum of squares of the residuals from the best-fit line is returned in sumsq. Note: the -/// correlation coefficient of the data can be computed using gsl_stats_correlation (see +/// The sum of squares of the residuals from the best-fit line is +/// returned in sumsq. Note: the correlation coefficient of the data +/// can be computed using gsl_stats_correlation (see /// [`Correlation`](http://www.gnu.org/software/gsl/manual/html_node/Correlation.html#Correlation)), /// it does not depend on the fit. /// /// Returns `(c0, c1, cov00, cov01, cov11, sumsq)`. +/// +/// # Example +/// +/// ``` +/// use rgsl::fit; +/// let (c0, c1, _, _, _, _) = fit::linear(&[0., 1.], &[0., 1.])?; +/// assert_eq!(c0, 0.); +/// assert_eq!(c1, 1.); +/// # Ok::<(), rgsl::Value>(()) +/// ``` #[doc(alias = "gsl_fit_linear")] -pub fn linear( - x: &[f64], - xstride: usize, - y: &[f64], - ystride: usize, - n: usize, -) -> Result<(f64, f64, f64, f64, f64, f64), Value> { - if (n - 1) * xstride >= x.len() || (n - 1) * ystride >= y.len() { - return Err(Value::Invalid); - } +pub fn linear>(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, f64), Value> { + check_equal_len(x, y)?; let mut c0 = 0.; let mut c1 = 0.; let mut cov00 = 0.; @@ -42,12 +49,12 @@ pub fn linear( let mut cov11 = 0.; let mut sumsq = 0.; let ret = unsafe { - sys::gsl_fit_linear( - x.as_ptr(), - xstride, - y.as_ptr(), - ystride, - n, + ::sys::gsl_fit_linear( + x.as_slice().as_ptr(), + x.stride(), + y.as_slice().as_ptr(), + y.stride(), + x.len(), &mut c0, &mut c1, &mut cov00, @@ -59,11 +66,12 @@ pub fn linear( result_handler!(ret, (c0, c1, cov00, cov01, cov11, sumsq)) } -/// This function computes the best-fit linear regression coefficients (c0,c1) of the model -/// Y = c_0 + c_1 X for the weighted dataset (x, y), two vectors of length n with strides xstride -/// and ystride. +/// This function computes the best-fit linear regression coefficients +/// (c0, c1) of the model Y = c_0 + c_1 X for the weighted dataset +/// (`x`, `y`), two vectors of the same length (possibly with strides). /// -/// The vector w, of length n and stride wstride, specifies the weight of each datapoint. +/// The vector `w`, of the same length as `x` and `y`, specifies the +/// weight of each datapoint. /// /// The weight is the reciprocal of the variance for each datapoint in y. /// @@ -73,19 +81,13 @@ pub fn linear( /// /// Returns `(c0, c1, cov00, cov01, cov11, chisq)`. #[doc(alias = "gsl_fit_wlinear")] -pub fn wlinear( - x: &[f64], - xstride: usize, - w: &[f64], - wstride: usize, - y: &[f64], - ystride: usize, - n: usize, +pub fn wlinear>( + x: &T, + w: &T, + y: &T, ) -> Result<(f64, f64, f64, f64, f64, f64), Value> { - if (n - 1) * xstride >= x.len() || (n - 1) * wstride >= w.len() || (n - 1) * ystride >= y.len() - { - return Err(Value::Invalid); - } + check_equal_len(x, y)?; + check_equal_len(x, w)?; let mut c0 = 0.; let mut c1 = 0.; let mut cov00 = 0.; @@ -93,14 +95,14 @@ pub fn wlinear( let mut cov11 = 0.; let mut chisq = 0.; let ret = unsafe { - sys::gsl_fit_wlinear( - x.as_ptr(), - xstride, - w.as_ptr(), - wstride, - y.as_ptr(), - ystride, - n, + ::sys::gsl_fit_wlinear( + x.as_slice().as_ptr(), + x.stride(), + w.as_slice().as_ptr(), + w.stride(), + y.as_slice().as_ptr(), + y.stride(), + x.len(), &mut c0, &mut c1, &mut cov00, @@ -141,26 +143,18 @@ pub fn linear_est( /// /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_mul")] -pub fn mul( - x: &[f64], - xstride: usize, - y: &[f64], - ystride: usize, - n: usize, -) -> Result<(f64, f64, f64), Value> { - if (n - 1) * xstride >= x.len() || (n - 1) * ystride >= y.len() { - return Err(Value::Invalid); - } +pub fn mul>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { + check_equal_len(x, y)?; let mut c1 = 0.; let mut cov11 = 0.; let mut sumsq = 0.; let ret = unsafe { sys::gsl_fit_mul( - x.as_ptr(), - xstride, - y.as_ptr(), - ystride, - n, + x.as_slice().as_ptr(), + x.stride(), + y.as_slice().as_ptr(), + y.stride(), + x.len(), &mut c1, &mut cov11, &mut sumsq, @@ -171,31 +165,21 @@ pub fn mul( /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_wmul")] -pub fn wmul( - x: &[f64], - xstride: usize, - w: &[f64], - wstride: usize, - y: &[f64], - ystride: usize, - n: usize, -) -> Result<(f64, f64, f64), Value> { - if (n - 1) * xstride >= x.len() || (n - 1) * wstride >= w.len() || (n - 1) * ystride >= y.len() - { - return Err(Value::Invalid); - } +pub fn wmul>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f64), Value> { + check_equal_len(x, y)?; + check_equal_len(x, w)?; let mut c1 = 0.; let mut cov11 = 0.; let mut sumsq = 0.; let ret = unsafe { sys::gsl_fit_wmul( - x.as_ptr(), - xstride, - w.as_ptr(), - wstride, - y.as_ptr(), - ystride, - n, + x.as_slice().as_ptr(), + x.stride(), + w.as_slice().as_ptr(), + w.stride(), + y.as_slice().as_ptr(), + y.stride(), + x.len(), &mut c1, &mut cov11, &mut sumsq, diff --git a/src/types/vector.rs b/src/types/vector.rs index 536a290e..52ee1ca7 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -38,6 +38,26 @@ use std::marker::PhantomData; use paste::paste; +#[cfg(feature = "complex")] +extern crate num_complex; +#[cfg(feature = "complex")] +use self::num_complex::Complex; + +/// Trait implemented by types that are considered vectors by this crate. +/// Elements of the vector are of type `F` (`f32` or `f64`). +pub trait Vector { + /// Return the number of elements in the vector. + fn len(&self) -> usize; + /// The distance in the slice between two consecutive elements. + fn stride(&self) -> usize; + /// Return a reference to the underlying slice. Note that the + /// `i`th element of the vector is the `i * stride` element in the + /// slice. + fn as_slice(&self) -> &[F]; + /// As [`as_slice`] but mutable. + fn as_mut_slice(&mut self) -> &mut [F]; +} + macro_rules! gsl_vec { ($rust_name:ident, $name:ident, $rust_ty:ident) => ( paste! { @@ -556,6 +576,26 @@ impl<'a> [<$rust_name View>]<'a> { } } // end of impl block + impl Vector<$rust_ty> for $rust_name { + fn len(&self) -> usize { + $rust_name::len(self) + } + fn stride(&self) -> usize { + let ptr = self.unwrap_shared(); + if ptr.is_null() { + 1 + } else { + unsafe { (*ptr).stride } + } + } + fn as_slice(&self) -> &[$rust_ty] { + $rust_name::as_slice(self).unwrap_or(&[]) + } + fn as_mut_slice(&mut self) -> &mut [$rust_ty] { + $rust_name::as_slice_mut(self).unwrap_or(&mut []) + } + } + } // end of paste! block ); // end of gsl_vec macro } @@ -564,3 +604,45 @@ gsl_vec!(VectorF32, gsl_vector_float, f32); gsl_vec!(VectorF64, gsl_vector, f64); gsl_vec!(VectorI32, gsl_vector_int, i32); gsl_vec!(VectorU32, gsl_vector_uint, u32); + +// Implement the `Vector` trait on standard vectors. + +macro_rules! impl_AsRef { + ($ty: ty) => { + impl Vector<$ty> for T + where + T: AsRef<[$ty]> + AsMut<[$ty]>, + { + fn len(&self) -> usize { + self.as_ref().len() + } + fn stride(&self) -> usize { + 1 + } + fn as_slice(&self) -> &[$ty] { + self.as_ref() + } + fn as_mut_slice(&mut self) -> &mut [$ty] { + self.as_mut() + } + } + }; +} + +impl_AsRef!(f32); +impl_AsRef!(f64); +#[cfg(feature = "complex")] +impl_AsRef!(Complex); +#[cfg(feature = "complex")] +impl_AsRef!(Complex); + +#[inline] +pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> +where + T: Vector, +{ + if x.len() != y.len() { + return Err(Value::Invalid); + } + Ok(()) +} From 3f7f4d510306c40c002f8da242acc07b454d593b Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 25 Mar 2024 00:33:53 +0100 Subject: [PATCH 02/28] Minor reformatting --- src/cblas.rs | 4 ++-- src/types/complex.rs | 3 +-- src/types/vector_complex.rs | 7 ++++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index da7d15fc..c24e171d 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -27,10 +27,10 @@ fn stride>(x: &T) -> i32 { } pub mod level1 { - #[cfg(feature = "complex")] - use num_complex::Complex; use super::{as_mut_ptr, as_ptr, len, stride}; use crate::vector::{check_equal_len, Vector}; + #[cfg(feature = "complex")] + use num_complex::Complex; /// Return the sum of `alpha` and the dot product of `x` and `y`. #[doc(alias = "cblas_sdsdot")] diff --git a/src/types/complex.rs b/src/types/complex.rs index 21f1e6ae..a41d013a 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -4,8 +4,7 @@ // TODO : port to Rust type : http://doc.rust-lang.org/num/complex/struct.Complex.html -use std::fmt; -use std::fmt::{Debug, Formatter}; +use std::fmt::{self, Debug, Formatter}; #[doc(hidden)] #[allow(clippy::upper_case_acronyms)] diff --git a/src/types/vector_complex.rs b/src/types/vector_complex.rs index 0658db6b..e71c2a5b 100644 --- a/src/types/vector_complex.rs +++ b/src/types/vector_complex.rs @@ -5,9 +5,10 @@ use crate::ffi::FFI; use crate::Value; use paste::paste; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::marker::PhantomData; +use std::{ + fmt::{self, Debug, Formatter}, + marker::PhantomData, +}; macro_rules! gsl_vec_complex { ($rust_name:ident, $name:ident, $complex:ident, $rust_ty:ident) => { From 619fb48983cff80d2f1dbb82e8f26585068376e9 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 25 Mar 2024 00:45:55 +0100 Subject: [PATCH 03/28] Appease clippy --- src/cblas.rs | 44 ++++++++++++++++++++++++++------------------ src/types/vector.rs | 1 + 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index c24e171d..ff1d4f76 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -511,23 +511,27 @@ pub mod level1 { &mut h as *mut _, ) } - let h = match h[0] { - -1.0 => H::Full { + let h = if h[0] == -1.0 { + H::Full { h11: h[1], h21: h[2], h12: h[3], h22: h[4], - }, - 0.0 => H::OffDiag { + } + } else if h[0] == 0.0 { + H::OffDiag { h21: h[2], h12: h[3], - }, - 1.0 => H::Diag { + } + } else if h[0] == 1.0 { + H::Diag { h11: h[1], h22: h[4], - }, - -2.0 => H::Id, - _ => unreachable!("srotmg: incorrect flag value"), + } + } else if h[0] == -2.0 { + H::Id + } else { + unreachable!("srotmg: incorrect flag value") }; (h, x1) } @@ -627,23 +631,27 @@ pub mod level1 { &mut h as *mut _, ) } - let h = match h[0] { - -1.0 => H::Full { + let h = if h[0] == -1.0 { + H::Full { h11: h[1], h21: h[2], h12: h[3], h22: h[4], - }, - 0.0 => H::OffDiag { + } + } else if h[0] == 0.0 { + H::OffDiag { h21: h[2], h12: h[3], - }, - 1.0 => H::Diag { + } + } else if h[0] == 1.0 { + H::Diag { h11: h[1], h22: h[4], - }, - -2.0 => H::Id, - _ => unreachable!("srotmg: incorrect flag value"), + } + } else if h[0] == -2.0 { + H::Id + } else { + unreachable!("srotmg: incorrect flag value") }; (h, x1) } diff --git a/src/types/vector.rs b/src/types/vector.rs index 52ee1ca7..31d87953 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -43,6 +43,7 @@ extern crate num_complex; #[cfg(feature = "complex")] use self::num_complex::Complex; +#[allow(clippy::len_without_is_empty)] /// Trait implemented by types that are considered vectors by this crate. /// Elements of the vector are of type `F` (`f32` or `f64`). pub trait Vector { From 0f27fe4acd6197292d72067c2e47c2585100df7b Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 25 Mar 2024 10:23:26 +0100 Subject: [PATCH 04/28] Don't require Vector types to be sized (e.g. accept slices) --- src/cblas.rs | 104 ++++++++++++++++++++++++------------------- src/fit.rs | 11 +++-- src/types/complex.rs | 1 + src/types/vector.rs | 4 +- 4 files changed, 68 insertions(+), 52 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index ff1d4f76..3d89736d 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -6,23 +6,23 @@ use crate::vector::Vector; /// Return the length of `x` as a `i32` value (to use in CBLAS calls). #[inline] -fn len>(x: &T) -> i32 { +fn len + ?Sized>(x: &T) -> i32 { x.len().try_into().expect("Length must fit in `i32`") } #[inline] -fn as_ptr>(x: &T) -> *const F { +fn as_ptr + ?Sized>(x: &T) -> *const F { x.as_slice().as_ptr() } #[inline] -fn as_mut_ptr>(x: &mut T) -> *mut F { +fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { x.as_mut_slice().as_mut_ptr() } /// Return the stride of `x` as a `i32` value (to use in CBLAS calls). #[inline] -fn stride>(x: &T) -> i32 { +fn stride + ?Sized>(x: &T) -> i32 { x.stride().try_into().expect("Stride must fit in `i32`") } @@ -34,28 +34,28 @@ pub mod level1 { /// Return the sum of `alpha` and the dot product of `x` and `y`. #[doc(alias = "cblas_sdsdot")] - pub fn sdsdot>(alpha: f32, x: &T, y: &T) -> f32 { + pub fn sdsdot + ?Sized>(alpha: f32, x: &T, y: &T) -> f32 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_sdsdot(len(x), alpha, as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_dsdot")] - pub fn dsdot>(x: &T, y: &T) -> f64 { + pub fn dsdot + ?Sized>(x: &T, y: &T) -> f64 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_dsdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_sdot")] - pub fn sdot>(x: &T, y: &T) -> f32 { + pub fn sdot + ?Sized>(x: &T, y: &T) -> f32 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_sdot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } /// Return the dot product of `x` and `y`. #[doc(alias = "cblas_ddot")] - pub fn ddot>(x: &T, y: &T) -> f64 { + pub fn ddot + ?Sized>(x: &T, y: &T) -> f64 { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); unsafe { sys::cblas_ddot(len(x), as_ptr(x), stride(x), as_ptr(y), stride(y)) } } @@ -73,7 +73,10 @@ pub mod level1 { /// assert_eq!(cdotu(&x, &x), Complex::new(3., 6.)) /// ``` #[doc(alias = "cblas_cdotu_sub")] - pub fn cdotu>>(x: &T, y: &T) -> Complex { + pub fn cdotu(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotu: Complex = Complex::new(0., 0.); unsafe { @@ -102,7 +105,10 @@ pub mod level1 { /// assert_eq!(cdotc(&x, &x), Complex::new(7., 0.)) /// ``` #[doc(alias = "cblas_cdotc_sub")] - pub fn cdotc>>(x: &T, y: &T) -> Complex { + pub fn cdotc(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotc: Complex = Complex::new(0., 0.); unsafe { @@ -131,7 +137,10 @@ pub mod level1 { /// assert_eq!(zdotu(&x, &x), Complex::new(3., 6.)) /// ``` #[doc(alias = "cblas_zdotu_sub")] - pub fn zdotu>>(x: &T, y: &T) -> Complex { + pub fn zdotu(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotu: Complex = Complex::new(0., 0.); unsafe { @@ -160,7 +169,10 @@ pub mod level1 { /// assert_eq!(zdotc(&x, &x), Complex::new(7., 0.)) /// ``` #[doc(alias = "cblas_zdotc_sub")] - pub fn zdotc>>(x: &T, y: &T) -> Complex { + pub fn zdotc(x: &T, y: &T) -> Complex + where + T: Vector> + ?Sized, + { check_equal_len(x, y).expect("The length of `x` and `y` must be equal"); let mut dotc: Complex = Complex::new(0., 0.); unsafe { @@ -178,27 +190,27 @@ pub mod level1 { /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_snrm2")] - pub fn snrm2>(x: &T) -> f32 { + pub fn snrm2 + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_snrm2(len(x), as_ptr(x), stride(x)) } } /// Return the sum of the absolute values of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_sasum")] - pub fn sasum>(x: &T) -> f32 { + pub fn sasum + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_sasum(len(x), as_ptr(x), stride(x)) } } /// Return the Euclidean norm of `x`. #[doc(alias = "cblas_dnrm2")] - pub fn dnrm2>(x: &T) -> f64 { + pub fn dnrm2 + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dnrm2(len(x), as_ptr(x), stride(x)) } } /// Return the sum of the absolute values of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_dasum")] - pub fn dasum>(x: &T) -> f64 { + pub fn dasum + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dasum(len(x), as_ptr(x), stride(x)) } } @@ -214,7 +226,7 @@ pub mod level1 { /// assert_eq!(scnrm2(&x), 7f32.sqrt()) /// ``` #[doc(alias = "cblas_scnrm2")] - pub fn scnrm2>>(x: &T) -> f32 { + pub fn scnrm2> + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_scnrm2(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -222,7 +234,7 @@ pub mod level1 { /// Return the sum of the modulus of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_scasum")] - pub fn scasum>>(x: &T) -> f32 { + pub fn scasum> + ?Sized>(x: &T) -> f32 { unsafe { sys::cblas_scasum(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -238,7 +250,7 @@ pub mod level1 { /// assert_eq!(dznrm2(&x), 7f64.sqrt()) /// ``` #[doc(alias = "cblas_dznrm2")] - pub fn dznrm2>>(x: &T) -> f64 { + pub fn dznrm2> + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dznrm2(len(x), as_ptr(x) as *const _, stride(x)) } } @@ -246,53 +258,53 @@ pub mod level1 { /// Return the sum of the modulus of the elements of `x` /// (i.e., its L¹-norm). #[doc(alias = "cblas_dzasum")] - pub fn dzasum>>(x: &T) -> f64 { + pub fn dzasum> + ?Sized>(x: &T) -> f64 { unsafe { sys::cblas_dzasum(len(x), as_ptr(x) as *const _, stride(x)) } } /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_isamax")] - pub fn isamax>(x: &T) -> usize { + pub fn isamax + ?Sized>(x: &T) -> usize { unsafe { sys::cblas_isamax(len(x), as_ptr(x), stride(x)) } } /// Return the index of the element with maximum absolute value. #[doc(alias = "cblas_idamax")] - pub fn idamax>(x: &T) -> usize { + pub fn idamax + ?Sized>(x: &T) -> usize { unsafe { sys::cblas_idamax(len(x), as_ptr(x), stride(x)) } } #[cfg(feature = "complex")] /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_icamax")] - pub fn icamax>>(x: &T) -> usize { + pub fn icamax> + ?Sized>(x: &T) -> usize { unsafe { sys::cblas_icamax(len(x), as_ptr(x) as *const _, stride(x)) } } #[cfg(feature = "complex")] /// Return the index of the element with maximum modulus. #[doc(alias = "cblas_izamax")] - pub fn izamax>>(x: &T) -> usize { + pub fn izamax> + ?Sized>(x: &T) -> usize { unsafe { sys::cblas_izamax(len(x), as_ptr(x) as *const _, stride(x)) } } /// Swap vectors `x` and `y`. #[doc(alias = "cblas_sswap")] - pub fn sswap>(x: &mut T, y: &mut T) { + pub fn sswap + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_scopy")] - pub fn scopy>(x: &T, y: &mut T) { + pub fn scopy + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_saxpy")] - pub fn saxpy>(alpha: f32, x: &T, y: &mut T) { + pub fn saxpy + ?Sized>(alpha: f32, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_saxpy( @@ -308,21 +320,21 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_dswap")] - pub fn dswap>(x: &mut T, y: &mut T) { + pub fn dswap + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_dcopy")] - pub fn dcopy>(x: &T, y: &mut T) { + pub fn dcopy + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_daxpy")] - pub fn daxpy>(alpha: f64, x: &T, y: &mut T) { + pub fn daxpy + ?Sized>(alpha: f64, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_daxpy( @@ -339,7 +351,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_cswap")] - pub fn cswap>>(x: &mut T, y: &mut T) { + pub fn cswap> + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_cswap( @@ -355,7 +367,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_ccopy")] - pub fn ccopy>>(x: &T, y: &mut T) { + pub fn ccopy> + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_ccopy( @@ -371,7 +383,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_caxpy")] - pub fn caxpy>>(alpha: &Complex, x: &T, y: &mut T) { + pub fn caxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_caxpy( @@ -388,7 +400,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_zswap")] - pub fn zswap>>(x: &mut T, y: &mut T) { + pub fn zswap> + ?Sized>(x: &mut T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zswap( @@ -404,7 +416,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_zcopy")] - pub fn zcopy>>(x: &T, y: &mut T) { + pub fn zcopy> + ?Sized>(x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zcopy( @@ -420,7 +432,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_zaxpy")] - pub fn zaxpy>>(alpha: &Complex, x: &T, y: &mut T) { + pub fn zaxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zaxpy( @@ -544,7 +556,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srot")] - pub fn srot>(x: &mut T, y: &mut T, c: f32, s: f32) { + pub fn srot + ?Sized>(x: &mut T, y: &mut T, c: f32, s: f32) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_srot( @@ -566,7 +578,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srotm")] - pub fn srotm>(x: &mut T, y: &mut T, h: H) { + pub fn srotm + ?Sized>(x: &mut T, y: &mut T, h: H) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -664,7 +676,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drot")] - pub fn drot>(x: &mut T, y: &mut T, c: f64, s: f64) { + pub fn drot + ?Sized>(x: &mut T, y: &mut T, c: f64, s: f64) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_drot( @@ -686,7 +698,7 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drotm")] - pub fn drotm>(x: &mut T, y: &mut T, h: H) { + pub fn drotm + ?Sized>(x: &mut T, y: &mut T, h: H) { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -708,20 +720,20 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_sscal")] - pub fn sscal>(alpha: f32, x: &mut T) { + pub fn sscal + ?Sized>(alpha: f32, x: &mut T) { unsafe { sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_dscal")] - pub fn dscal>(alpha: f64, x: &mut T) { + pub fn dscal + ?Sized>(alpha: f64, x: &mut T) { unsafe { sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_cscal")] - pub fn cscal>>(alpha: &Complex, x: &mut T) { + pub fn cscal> + ?Sized>(alpha: &Complex, x: &mut T) { unsafe { sys::cblas_cscal( len(x), @@ -735,7 +747,7 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_zscal")] - pub fn zscal>>(alpha: &Complex, x: &mut T) { + pub fn zscal> + ?Sized>(alpha: &Complex, x: &mut T) { unsafe { sys::cblas_zscal( len(x), @@ -749,14 +761,14 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_csscal")] - pub fn csscal>>(alpha: f32, x: &mut T) { + pub fn csscal> + ?Sized>(alpha: f32, x: &mut T) { unsafe { sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } #[cfg(feature = "complex")] /// Multiple each element of a matrix/vector by a constant. #[doc(alias = "cblas_zdscal")] - pub fn zdscal>>(alpha: f64, x: &mut T) { + pub fn zdscal> + ?Sized>(alpha: f64, x: &mut T) { unsafe { sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } } diff --git a/src/fit.rs b/src/fit.rs index 58c6d321..a9c82db5 100644 --- a/src/fit.rs +++ b/src/fit.rs @@ -40,7 +40,10 @@ use crate::{ /// # Ok::<(), rgsl::Value>(()) /// ``` #[doc(alias = "gsl_fit_linear")] -pub fn linear>(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, f64), Value> { +pub fn linear(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, f64), Value> +where + T: Vector + ?Sized, +{ check_equal_len(x, y)?; let mut c0 = 0.; let mut c1 = 0.; @@ -81,7 +84,7 @@ pub fn linear>(x: &T, y: &T) -> Result<(f64, f64, f64, f64, f64, /// /// Returns `(c0, c1, cov00, cov01, cov11, chisq)`. #[doc(alias = "gsl_fit_wlinear")] -pub fn wlinear>( +pub fn wlinear + ?Sized>( x: &T, w: &T, y: &T, @@ -143,7 +146,7 @@ pub fn linear_est( /// /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_mul")] -pub fn mul>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { +pub fn mul + ?Sized>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { check_equal_len(x, y)?; let mut c1 = 0.; let mut cov11 = 0.; @@ -165,7 +168,7 @@ pub fn mul>(x: &T, y: &T) -> Result<(f64, f64, f64), Value> { /// Returns `(c1, cov11, sumsq)`. #[doc(alias = "gsl_fit_wmul")] -pub fn wmul>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f64), Value> { +pub fn wmul + ?Sized>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f64), Value> { check_equal_len(x, y)?; check_equal_len(x, w)?; let mut c1 = 0.; diff --git a/src/types/complex.rs b/src/types/complex.rs index a41d013a..2d021b36 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -20,6 +20,7 @@ pub trait FFFI { fn unwrap(t: T) -> Self; } +//#[deprecated(note = "Use `Complex64` from the `num_complex` create instead")] #[repr(C)] #[derive(Clone, Copy, PartialEq)] pub struct ComplexF64 { diff --git a/src/types/vector.rs b/src/types/vector.rs index 31d87953..e3994fcc 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -612,7 +612,7 @@ macro_rules! impl_AsRef { ($ty: ty) => { impl Vector<$ty> for T where - T: AsRef<[$ty]> + AsMut<[$ty]>, + T: AsRef<[$ty]> + AsMut<[$ty]> + ?Sized, { fn len(&self) -> usize { self.as_ref().len() @@ -640,7 +640,7 @@ impl_AsRef!(Complex); #[inline] pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> where - T: Vector, + T: Vector + ?Sized, { if x.len() != y.len() { return Err(Value::Invalid); From cf90e1a1fc2dec634000c7d7f01d4bb41bd24bb5 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Sun, 14 Apr 2024 22:07:00 +0200 Subject: [PATCH 05/28] vector::Vector: use associated functions to avoid conflicts As this trait will be requested to be brought into scope to set strides, it is good that the functions required to define its capabilities do not conflict with methods that the type may originally have. --- src/cblas.rs | 30 ++----------- src/fit.rs | 48 ++++++++++----------- src/types/vector.rs | 102 +++++++++++++++++++++++++++++++------------- 3 files changed, 100 insertions(+), 80 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index 3d89736d..cbfa9369 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -2,33 +2,11 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // -use crate::vector::Vector; - -/// Return the length of `x` as a `i32` value (to use in CBLAS calls). -#[inline] -fn len + ?Sized>(x: &T) -> i32 { - x.len().try_into().expect("Length must fit in `i32`") -} - -#[inline] -fn as_ptr + ?Sized>(x: &T) -> *const F { - x.as_slice().as_ptr() -} - -#[inline] -fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { - x.as_mut_slice().as_mut_ptr() -} - -/// Return the stride of `x` as a `i32` value (to use in CBLAS calls). -#[inline] -fn stride + ?Sized>(x: &T) -> i32 { - x.stride().try_into().expect("Stride must fit in `i32`") -} - pub mod level1 { - use super::{as_mut_ptr, as_ptr, len, stride}; - use crate::vector::{check_equal_len, Vector}; + use crate::vector::{ + Vector, + as_mut_ptr, as_ptr, len, stride, check_equal_len, + }; #[cfg(feature = "complex")] use num_complex::Complex; diff --git a/src/fit.rs b/src/fit.rs index a9c82db5..ab9cc1ea 100644 --- a/src/fit.rs +++ b/src/fit.rs @@ -53,11 +53,11 @@ where let mut sumsq = 0.; let ret = unsafe { ::sys::gsl_fit_linear( - x.as_slice().as_ptr(), - x.stride(), - y.as_slice().as_ptr(), - y.stride(), - x.len(), + T::as_slice(x).as_ptr(), + T::stride(x), + T::as_slice(y).as_ptr(), + T::stride(y), + T::len(x), &mut c0, &mut c1, &mut cov00, @@ -99,13 +99,13 @@ pub fn wlinear + ?Sized>( let mut chisq = 0.; let ret = unsafe { ::sys::gsl_fit_wlinear( - x.as_slice().as_ptr(), - x.stride(), - w.as_slice().as_ptr(), - w.stride(), - y.as_slice().as_ptr(), - y.stride(), - x.len(), + T::as_slice(x).as_ptr(), + T::stride(x), + T::as_slice(w).as_ptr(), + T::stride(w), + T::as_slice(y).as_ptr(), + T::stride(y), + T::len(x), &mut c0, &mut c1, &mut cov00, @@ -153,11 +153,11 @@ pub fn mul + ?Sized>(x: &T, y: &T) -> Result<(f64, f64, f64), Val let mut sumsq = 0.; let ret = unsafe { sys::gsl_fit_mul( - x.as_slice().as_ptr(), - x.stride(), - y.as_slice().as_ptr(), - y.stride(), - x.len(), + T::as_slice(x).as_ptr(), + T::stride(x), + T::as_slice(y).as_ptr(), + T::stride(y), + T::len(x), &mut c1, &mut cov11, &mut sumsq, @@ -176,13 +176,13 @@ pub fn wmul + ?Sized>(x: &T, w: &T, y: &T) -> Result<(f64, f64, f let mut sumsq = 0.; let ret = unsafe { sys::gsl_fit_wmul( - x.as_slice().as_ptr(), - x.stride(), - w.as_slice().as_ptr(), - w.stride(), - y.as_slice().as_ptr(), - y.stride(), - x.len(), + T::as_slice(x).as_ptr(), + T::stride(x), + T::as_slice(w).as_ptr(), + T::stride(w), + T::as_slice(y).as_ptr(), + T::stride(y), + T::len(x), &mut c1, &mut cov11, &mut sumsq, diff --git a/src/types/vector.rs b/src/types/vector.rs index e3994fcc..e9a4b73e 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -48,17 +48,61 @@ use self::num_complex::Complex; /// Elements of the vector are of type `F` (`f32` or `f64`). pub trait Vector { /// Return the number of elements in the vector. - fn len(&self) -> usize; - /// The distance in the slice between two consecutive elements. - fn stride(&self) -> usize; + /// + /// This is an associated function rather than a method to avoid + /// conflicting with similarly named methods. + fn len(x: &Self) -> usize; + + /// The distance in the slice between two consecutive elements of + /// the vector in [`Vector::as_slice`] and [`Vector::as_mut_slice`]. + fn stride(x: &Self) -> usize; + /// Return a reference to the underlying slice. Note that the /// `i`th element of the vector is the `i * stride` element in the /// slice. - fn as_slice(&self) -> &[F]; + /// + /// This is an associated function rather than a method to avoid + /// conflicting with similarly named methods. + fn as_slice(x: &Self) -> &[F]; + /// As [`as_slice`] but mutable. - fn as_mut_slice(&mut self) -> &mut [F]; + fn as_mut_slice(x: &mut Self) -> &mut [F]; } +/// Return the length of `x` as a `i32` value (to use in CBLAS calls). +#[inline] +pub(crate) fn len + ?Sized>(x: &T) -> i32 { + T::len(x).try_into().expect("Length must fit in `i32`") +} + +#[inline] +pub(crate) fn as_ptr + ?Sized>(x: &T) -> *const F { + T::as_slice(x).as_ptr() +} + +#[inline] +pub(crate) fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { + T::as_mut_slice(x).as_mut_ptr() +} + +/// Return the stride of `x` as a `i32` value (to use in CBLAS calls). +#[inline] +pub(crate) fn stride + ?Sized>(x: &T) -> i32 { + T::stride(x).try_into().expect("Stride must fit in `i32`") +} + +#[inline] +pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> +where + T: Vector + ?Sized, +{ + if T::len(x) != T::len(y) { + return Err(Value::Invalid); + } + Ok(()) +} + + macro_rules! gsl_vec { ($rust_name:ident, $name:ident, $rust_ty:ident) => ( paste! { @@ -578,22 +622,26 @@ impl<'a> [<$rust_name View>]<'a> { } // end of impl block impl Vector<$rust_ty> for $rust_name { - fn len(&self) -> usize { - $rust_name::len(self) + #[inline] + fn len(x: &Self) -> usize { + $rust_name::len(x) } - fn stride(&self) -> usize { - let ptr = self.unwrap_shared(); + #[inline] + fn stride(x: &Self) -> usize { + let ptr = x.unwrap_shared(); if ptr.is_null() { 1 } else { unsafe { (*ptr).stride } } } - fn as_slice(&self) -> &[$rust_ty] { - $rust_name::as_slice(self).unwrap_or(&[]) + #[inline] + fn as_slice(x: &Self) -> &[$rust_ty] { + $rust_name::as_slice(x).unwrap_or(&[]) } - fn as_mut_slice(&mut self) -> &mut [$rust_ty] { - $rust_name::as_slice_mut(self).unwrap_or(&mut []) + #[inline] + fn as_mut_slice(x: &mut Self) -> &mut [$rust_ty] { + $rust_name::as_slice_mut(x).unwrap_or(&mut []) } } @@ -614,17 +662,21 @@ macro_rules! impl_AsRef { where T: AsRef<[$ty]> + AsMut<[$ty]> + ?Sized, { - fn len(&self) -> usize { - self.as_ref().len() + #[inline] + fn len(x: &Self) -> usize { + x.as_ref().len() } - fn stride(&self) -> usize { + #[inline] + fn stride(_: &Self) -> usize { 1 } - fn as_slice(&self) -> &[$ty] { - self.as_ref() + #[inline] + fn as_slice(x: &Self) -> &[$ty] { + x.as_ref() } - fn as_mut_slice(&mut self) -> &mut [$ty] { - self.as_mut() + #[inline] + fn as_mut_slice(x: &mut Self) -> &mut [$ty] { + x.as_mut() } } }; @@ -637,13 +689,3 @@ impl_AsRef!(Complex); #[cfg(feature = "complex")] impl_AsRef!(Complex); -#[inline] -pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> -where - T: Vector + ?Sized, -{ - if x.len() != y.len() { - return Err(Value::Invalid); - } - Ok(()) -} From 222c28c70927c7c560689e6074a39459fa524ce7 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 20:58:20 +0200 Subject: [PATCH 06/28] Distinguish mutable and immutable vector traits (for slices) --- src/cblas.rs | 80 ++++++++++++++++++++++--------- src/types/vector.rs | 114 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 160 insertions(+), 34 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index cbfa9369..7a335688 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -4,7 +4,7 @@ pub mod level1 { use crate::vector::{ - Vector, + Vector, VectorMut, as_mut_ptr, as_ptr, len, stride, check_equal_len, }; #[cfg(feature = "complex")] @@ -268,21 +268,24 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_sswap")] - pub fn sswap + ?Sized>(x: &mut T, y: &mut T) { + pub fn sswap(x: &mut T1, y: &mut T2) + where T1: VectorMut + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_scopy")] - pub fn scopy + ?Sized>(x: &T, y: &mut T) { + pub fn scopy(x: &T1, y: &mut T2) + where T1: Vector + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_saxpy")] - pub fn saxpy + ?Sized>(alpha: f32, x: &T, y: &mut T) { + pub fn saxpy(alpha: f32, x: &T1, y: &mut T2) + where T1: Vector + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_saxpy( @@ -298,21 +301,24 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_dswap")] - pub fn dswap + ?Sized>(x: &mut T, y: &mut T) { + pub fn dswap(x: &mut T1, y: &mut T2) + where T1: VectorMut + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// Copy the content of `x` into `y`. #[doc(alias = "cblas_dcopy")] - pub fn dcopy + ?Sized>(x: &T, y: &mut T) { + pub fn dcopy(x: &T1, y: &mut T2) + where T1: Vector + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_daxpy")] - pub fn daxpy + ?Sized>(alpha: f64, x: &T, y: &mut T) { + pub fn daxpy(alpha: f64, x: &T1, y: &mut T2) + where T1: Vector + ?Sized, T2: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_daxpy( @@ -329,7 +335,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_cswap")] - pub fn cswap> + ?Sized>(x: &mut T, y: &mut T) { + pub fn cswap(x: &mut T1, y: &mut T2) + where T1: VectorMut> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_cswap( @@ -345,7 +354,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_ccopy")] - pub fn ccopy> + ?Sized>(x: &T, y: &mut T) { + pub fn ccopy(x: &T1, y: &mut T2) + where T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_ccopy( @@ -361,7 +373,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_caxpy")] - pub fn caxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { + pub fn caxpy(alpha: &Complex, x: &T1, y: &mut T2) + where T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_caxpy( @@ -378,7 +393,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// Swap vectors `x` and `y`. #[doc(alias = "cblas_zswap")] - pub fn zswap> + ?Sized>(x: &mut T, y: &mut T) { + pub fn zswap(x: &mut T1, y: &mut T2) + where T1: VectorMut> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zswap( @@ -394,7 +412,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// Copy the content of `x` into `y`. #[doc(alias = "cblas_zcopy")] - pub fn zcopy> + ?Sized>(x: &T, y: &mut T) { + pub fn zcopy(x: &T1, y: &mut T2) + where T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zcopy( @@ -410,7 +431,10 @@ pub mod level1 { #[cfg(feature = "complex")] /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_zaxpy")] - pub fn zaxpy> + ?Sized>(alpha: &Complex, x: &T, y: &mut T) { + pub fn zaxpy(alpha: &Complex, x: &T1, y: &mut T2) + where T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_zaxpy( @@ -534,7 +558,8 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srot")] - pub fn srot + ?Sized>(x: &mut T, y: &mut T, c: f32, s: f32) { + pub fn srot(x: &mut T, y: &mut T, c: f32, s: f32) + where T: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_srot( @@ -556,7 +581,8 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_srotm")] - pub fn srotm + ?Sized>(x: &mut T, y: &mut T, h: H) { + pub fn srotm(x: &mut T, y: &mut T, h: H) + where T: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -654,7 +680,8 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drot")] - pub fn drot + ?Sized>(x: &mut T, y: &mut T, c: f64, s: f64) { + pub fn drot(x: &mut T, y: &mut T, c: f64, s: f64) + where T: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_drot( @@ -676,7 +703,8 @@ pub mod level1 { /// /// for all indices i. #[doc(alias = "cblas_drotm")] - pub fn drotm + ?Sized>(x: &mut T, y: &mut T, h: H) { + pub fn drotm(x: &mut T, y: &mut T, h: H) + where T: VectorMut + ?Sized { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -698,20 +726,23 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_sscal")] - pub fn sscal + ?Sized>(alpha: f32, x: &mut T) { + pub fn sscal(alpha: f32, x: &mut T) + where T: VectorMut + ?Sized { unsafe { sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_dscal")] - pub fn dscal + ?Sized>(alpha: f64, x: &mut T) { + pub fn dscal(alpha: f64, x: &mut T) + where T: VectorMut + ?Sized { unsafe { sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_cscal")] - pub fn cscal> + ?Sized>(alpha: &Complex, x: &mut T) { + pub fn cscal(alpha: &Complex, x: &mut T) + where T: VectorMut> + ?Sized { unsafe { sys::cblas_cscal( len(x), @@ -725,7 +756,8 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_zscal")] - pub fn zscal> + ?Sized>(alpha: &Complex, x: &mut T) { + pub fn zscal(alpha: &Complex, x: &mut T) + where T: VectorMut> + ?Sized { unsafe { sys::cblas_zscal( len(x), @@ -739,14 +771,16 @@ pub mod level1 { #[cfg(feature = "complex")] /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_csscal")] - pub fn csscal> + ?Sized>(alpha: f32, x: &mut T) { + pub fn csscal(alpha: f32, x: &mut T) + where T: VectorMut> + ?Sized { unsafe { sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } #[cfg(feature = "complex")] /// Multiple each element of a matrix/vector by a constant. #[doc(alias = "cblas_zdscal")] - pub fn zdscal> + ?Sized>(alpha: f64, x: &mut T) { + pub fn zdscal(alpha: f64, x: &mut T) + where T: VectorMut> + ?Sized { unsafe { sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } } diff --git a/src/types/vector.rs b/src/types/vector.rs index e9a4b73e..4c410d56 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -32,9 +32,11 @@ vector. use crate::ffi::FFI; use crate::Value; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::marker::PhantomData; +use std::{ + fmt::{self, Debug, Formatter}, + marker::PhantomData, + ops::Range, +}; use paste::paste; @@ -46,6 +48,9 @@ use self::num_complex::Complex; #[allow(clippy::len_without_is_empty)] /// Trait implemented by types that are considered vectors by this crate. /// Elements of the vector are of type `F` (`f32` or `f64`). +/// +/// Bring this trait into scope in order to add methods to specify +/// strides to the types implementing `Vector`. pub trait Vector { /// Return the number of elements in the vector. /// @@ -58,15 +63,93 @@ pub trait Vector { fn stride(x: &Self) -> usize; /// Return a reference to the underlying slice. Note that the - /// `i`th element of the vector is the `i * stride` element in the - /// slice. + /// `i`th element of the vector, `0 <= i < len(x)`, is the + /// `i * stride` element in the slice. /// /// This is an associated function rather than a method to avoid /// conflicting with similarly named methods. fn as_slice(x: &Self) -> &[F]; - /// As [`as_slice`] but mutable. + fn slice(&self, r: Range) -> Option> { + let stride = Self::stride(self); + let slice = Self::as_slice(self); + // FIXME: use std::slice::SliceIndex methods when stable. + if r.end == 0 { + return Some(Slice { vec: &slice[0..0], len: 0, stride: 1 }) + } + let end = (r.end - 1) * stride + 1; + if r.start > r.end || end > slice.len() { + None + } else { + let start = r.start * stride; + Some(Slice { + vec: &slice[start .. end], + len: r.end - r.start, + stride + }) + } + } +} + +pub trait VectorMut: Vector { + /// Same as [`Vector::as_slice`] but mutable. fn as_mut_slice(x: &mut Self) -> &mut [F]; + + /// Same as [`Vector::slice`] but mutable. + fn slice_mut(&mut self, r: Range) -> Option> { + let stride = Self::stride(self); + let slice = Self::as_mut_slice(self); + // FIXME: use std::slice::SliceIndex methods when stable. + if r.end == 0 { + return Some(SliceMut { vec: &mut slice[0..0], len: 0, stride: 1 }) + } + let end = (r.end - 1) * stride + 1; + if r.start > r.end || end > slice.len() { + None + } else { + let start = r.start * stride; + Some(SliceMut { + vec: &mut slice[start .. end], + len: r.end - r.start, + stride + }) + } + } +} + +pub struct Slice<'a, F> { + vec: &'a [F], + len: usize, + stride: usize, +} + +pub struct SliceMut<'a, F> { + vec: &'a mut [F], + len: usize, + stride: usize, +} + +impl<'a, F> Vector for Slice<'a, F> { + #[inline] + fn len(x: &Self) -> usize { x.len } + #[inline] + fn stride(x: &Self) -> usize { x.stride } + #[inline] + fn as_slice(x: &Self) -> &[F] { x.vec } +} + +impl<'a, F> Vector for SliceMut<'a, F> { + #[inline] + fn len(x: &Self) -> usize { x.len } + #[inline] + fn stride(x: &Self) -> usize { x.stride } + #[inline] + fn as_slice(x: &Self) -> &[F] { x.vec } +} + +impl<'a, F> VectorMut for SliceMut<'a, F> { + #[inline] + fn as_mut_slice(x: &mut Self) -> &mut [F] { x.vec } } /// Return the length of `x` as a `i32` value (to use in CBLAS calls). @@ -81,7 +164,7 @@ pub(crate) fn as_ptr + ?Sized>(x: &T) -> *const F { } #[inline] -pub(crate) fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { +pub(crate) fn as_mut_ptr + ?Sized>(x: &mut T) -> *mut F { T::as_mut_slice(x).as_mut_ptr() } @@ -92,11 +175,12 @@ pub(crate) fn stride + ?Sized>(x: &T) -> i32 { } #[inline] -pub(crate) fn check_equal_len(x: &T, y: &T) -> Result<(), Value> +pub(crate) fn check_equal_len(x: &T1, y: &T2) -> Result<(), Value> where - T: Vector + ?Sized, + T1: Vector + ?Sized, + T2: Vector + ?Sized, { - if T::len(x) != T::len(y) { + if T1::len(x) != T2::len(y) { return Err(Value::Invalid); } Ok(()) @@ -639,6 +723,8 @@ impl<'a> [<$rust_name View>]<'a> { fn as_slice(x: &Self) -> &[$rust_ty] { $rust_name::as_slice(x).unwrap_or(&[]) } + } + impl VectorMut<$rust_ty> for $rust_name { #[inline] fn as_mut_slice(x: &mut Self) -> &mut [$rust_ty] { $rust_name::as_slice_mut(x).unwrap_or(&mut []) @@ -660,7 +746,7 @@ macro_rules! impl_AsRef { ($ty: ty) => { impl Vector<$ty> for T where - T: AsRef<[$ty]> + AsMut<[$ty]> + ?Sized, + T: AsRef<[$ty]> + ?Sized, { #[inline] fn len(x: &Self) -> usize { @@ -674,6 +760,12 @@ macro_rules! impl_AsRef { fn as_slice(x: &Self) -> &[$ty] { x.as_ref() } + } + + impl VectorMut<$ty> for T + where + T: Vector<$ty> + AsMut<[$ty]> + ?Sized, + { #[inline] fn as_mut_slice(x: &mut Self) -> &mut [$ty] { x.as_mut() From 73b43a9072cce4e3d66a207159641719d25a144b Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 20:58:57 +0200 Subject: [PATCH 07/28] Let the `sort` module use the vector trait (and be safe) --- src/sort.rs | 247 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 171 insertions(+), 76 deletions(-) diff --git a/src/sort.rs b/src/sort.rs index 340f4ba4..87a65992 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -32,39 +32,112 @@ Robert Sedgewick, Algorithms in C, Addison-Wesley, ISBN 0201514257. pub mod vectors { use crate::ffi::FFI; use crate::types::{Permutation, VectorF64}; + use crate::vector::{self, check_equal_len, Vector, VectorMut}; use crate::Value; - /// This function sorts the n elements of the array data with stride stride into ascending numerical order. + /// This function sorts the elements of the array `data` into + /// ascending numerical order. + /// + /// # Examples + /// + /// ``` + /// use rgsl::sort::vectors::sort; + /// let mut data = [4., 1., 3., 2.]; + /// sort(&mut data); + /// assert_eq!(data, [1., 2., 3., 4.]); + /// ``` + /// + /// The same function can also be used with GSL vectors: + /// + /// ``` + /// use rgsl::{vector::VectorF64, sort::vectors::sort}; + /// let mut data = VectorF64::from_slice(&[4., 1., 3., 2.]).unwrap(); + /// sort(&mut data); + /// assert_eq!(data.as_slice().unwrap(), [1., 2., 3., 4.]); + /// ``` #[doc(alias = "gsl_sort")] - pub fn sort(data: &mut [f64], stride: usize, n: usize) { - unsafe { sys::gsl_sort(data.as_mut_ptr(), stride, n) } + pub fn sort(data: &mut T) + where T: VectorMut + ?Sized { + unsafe { sys::gsl_sort( + vector::as_mut_ptr(data), + T::stride(data), + T::len(data)) + } } - /// This function sorts the n elements of the array data1 with stride stride1 into ascending numerical order, while making the same rearrangement - /// of the array data2 with stride stride2, also of size n. + /// This function sorts the elements of the array `data1` into + /// ascending numerical order, while making the same rearrangement + /// of the array `data2`. Panic if `data1` and `data2` do not + /// have the same length. + /// + /// # Example + /// + /// ``` + /// use rgsl::sort::vectors::sort2; + /// let mut data1 = [4., 1., 3., 2.]; + /// let mut data2 = [10., 20., 30., 40.]; + /// sort2(&mut data1, &mut data2); + /// assert_eq!(data1, [1., 2., 3., 4.]); + /// assert_eq!(data2, [20., 40., 30., 10.]); + /// ``` + /// + /// The same function can also be used with GSL vectors: + /// + /// ``` + /// use rgsl::{vector::VectorF64, sort::vectors::sort2}; + /// let mut data1 = VectorF64::from_slice(&[4., 1., 3., 2.]).unwrap(); + /// let mut data2 = VectorF64::from_slice(&[10., 20., 30., 40.]).unwrap(); + /// sort2(&mut data1, &mut data2); + /// assert_eq!(data1.as_slice().unwrap(), [1., 2., 3., 4.]); + /// assert_eq!(data2.as_slice().unwrap(), [20., 40., 30., 10.]); + /// ``` #[doc(alias = "gsl_sort2")] - pub fn sort2(data1: &mut [f64], stride1: usize, data2: &mut [f64], stride2: usize, n: usize) { - unsafe { sys::gsl_sort2(data1.as_mut_ptr(), stride1, data2.as_mut_ptr(), stride2, n) } + pub fn sort2(data1: &mut T1, data2: &mut T2) + where T1: VectorMut + ?Sized, + T2: VectorMut + ?Sized + { + check_equal_len(data1, data2) + .expect("rgsl::sort::sort2: the vectors must have the same length"); + unsafe { sys::gsl_sort2( + vector::as_mut_ptr(data1), + T1::stride(data1), + vector::as_mut_ptr(data2), + T2::stride(data2), + T1::len(data1)) } } /// This function sorts the elements of the vector v into ascending numerical order. #[doc(alias = "gsl_sort_vector")] + #[deprecated(since="8.0", note="Please use `sort` instead")] pub fn sort_vector(v: &mut VectorF64) { unsafe { sys::gsl_sort_vector(v.unwrap_unique()) } } /// This function sorts the elements of the vector v1 into ascending numerical order, while making the same rearrangement of the vector v2. #[doc(alias = "gsl_sort_vector2")] + #[deprecated(since="8.0", note="Please use `sort2` instead")] pub fn sort_vector2(v1: &mut VectorF64, v2: &mut VectorF64) { unsafe { sys::gsl_sort_vector2(v1.unwrap_unique(), v2.unwrap_unique()) } } - /// This function indirectly sorts the n elements of the array data with stride stride into ascending order, storing the resulting - /// permutation in p. The array p must be allocated with a sufficient length to store the n elements of the permutation. The elements of p - /// give the index of the array element which would have been stored in that position if the array had been sorted in place. The array data is not changed. + /// This function indirectly sorts the elements of the array + /// `data` into ascending order, storing the resulting permutation + /// in `p`. The slice `p` must have the same length as `data`. + /// The elements of `p` give the index of the array element which + /// would have been stored in that position if the array had been + /// sorted in place. #[doc(alias = "gsl_sort_index")] - pub fn sort_index(p: &mut [usize], data: &[f64], stride: usize, n: usize) { - unsafe { sys::gsl_sort_index(p.as_mut_ptr(), data.as_ptr(), stride, n) } + pub fn sort_index(p: &mut [usize], data: &T) + where T: Vector + ?Sized { + if p.len() != T::len(data) { + panic!("rgsl::sort::vectors::sort_index: `p` and `data` must have the same length"); + } + unsafe { sys::gsl_sort_index( + p.as_mut_ptr(), + vector::as_ptr(data), + T::stride(data), + T::len(data)) + } } /// This function indirectly sorts the elements of the vector v into ascending order, storing the resulting permutation in p. The elements of p give the @@ -84,105 +157,127 @@ pub mod vectors { pub mod select { use crate::ffi::FFI; use crate::types::VectorF64; + use crate::vector::{self, Vector}; use crate::Value; - /// This function copies the k smallest elements of the array src, of size n and stride stride, in ascending numerical order into the array dest. The size - /// k of the subset must be less than or equal to n. The data src is not modified by this operation. + /// This function copies the `dest.len()` smallest elements of the + /// array `src`, in ascending numerical order into the array + /// `dest`. Panic if `dest.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_smallest")] - pub fn sort_smallest( - dest: &mut [f64], - k: usize, - src: &[f64], - stride: usize, - ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_sort_smallest(dest.as_mut_ptr(), k, src.as_ptr(), stride, src.len() as _) + pub fn sort_smallest(dest: &mut [f64], src: &T) -> Result<(), Value> + where T: Vector + ?Sized { + if dest.len() > T::len(src) { + panic!("rgsl::sort::select::sort_smallest: `dest.len() > src.len()`"); + } + let ret = unsafe { sys::gsl_sort_smallest( + dest.as_mut_ptr(), dest.len(), + vector::as_ptr(src), T::stride(src), T::len(src)) }; result_handler!(ret, ()) } - /// This function copies the k largest elements of the array src, of size n and stride stride, in descending numerical order into the array dest. k must - /// be less than or equal to n. The data src is not modified by this operation. + /// This function copies the `dest.len()` largest elements of the + /// array `src` in descending numerical order into the array + /// `dest`. Panic if `dest.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_largest")] - pub fn sort_largest( - dest: &mut [f64], - k: usize, - src: &[f64], - stride: usize, - ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_sort_largest(dest.as_mut_ptr(), k, src.as_ptr(), stride, src.len() as _) + pub fn sort_largest(dest: &mut [f64], src: &T) -> Result<(), Value> + where T: Vector + ?Sized { + if dest.len() > T::len(src) { + panic!("rgsl::sort::select::sort_largest: `dest.len() > src.len()`"); + } + let ret = unsafe { sys::gsl_sort_largest( + dest.as_mut_ptr(), dest.len(), + vector::as_ptr(src), T::stride(src), T::len(src)) }; result_handler!(ret, ()) } - /// This function copies the k smallest or largest elements of the vector v into the array dest. k must be less than or equal to the length of the vector v. + /// This function copies the `dest.len()` smallest elements of the + /// vector `v` into the slice `dest`. Panic if `dest.len()` is + /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest")] - pub fn sort_vector_smallest(dest: &mut [f64], k: usize, v: &VectorF64) -> Result<(), Value> { - let ret = unsafe { sys::gsl_sort_vector_smallest(dest.as_mut_ptr(), k, v.unwrap_shared()) }; + #[deprecated(since="8.0", note="Please use `sort_smallest` instead")] + pub fn sort_vector_smallest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { + if dest.len() > v.len() { + panic!("rgsl::sort::select::sort_vector_smallest: `dest.len() > v.len()`"); + } + let ret = unsafe { sys::gsl_sort_vector_smallest( + dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) }; result_handler!(ret, ()) } - /// This function copies the k smallest or largest elements of the vector v into the array dest. k must be less than or equal to the length of the vector v. + /// This function copies the `dest.len()` largest elements of the + /// vector `v` into the array dest. Panic if `dest.len()` is + /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest")] - pub fn sort_vector_largest(dest: &mut [f64], k: usize, v: &VectorF64) -> Result<(), Value> { - let ret = unsafe { sys::gsl_sort_vector_largest(dest.as_mut_ptr(), k, v.unwrap_shared()) }; + #[deprecated(since="8.0", note="Please use `sort_largest` instead")] + pub fn sort_vector_largest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { + if dest.len() > v.len() { + panic!("rgsl::sort::select::sort_vector_largest: `dest.len() > v.len()`"); + } + let ret = unsafe { sys::gsl_sort_vector_largest( + dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) }; result_handler!(ret, ()) } - /// This function stores the indices of the k smallest elements of the array src, of size n and stride stride, in the array p. The indices are chosen so that - /// the corresponding data is in ascending numerical order. k must be less than or equal to n. The data src is not modified by this operation. + /// This function stores the indices of the `p.len()` smallest + /// elements of the vector `src` in the slice `p`. The indices are + /// chosen so that the corresponding data is in ascending + /// numerical order. Panic if `p.len()` is larger than the size + /// of `src`. #[doc(alias = "gsl_sort_smallest_index")] - pub fn sort_smallest_index( - p: &mut [usize], - k: usize, - src: &[f64], - stride: usize, - ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_sort_smallest_index(p.as_mut_ptr(), k, src.as_ptr(), stride, src.len() as _) + pub fn sort_smallest_index(p: &mut [usize], src: &T) -> Result<(), Value> + where T: Vector + ?Sized { + if p.len() > T::len(src) { + panic!("rgsl::sort::select::sort_smallest_index: `p.len() > src.len()`"); + } + let ret = unsafe { sys::gsl_sort_smallest_index( + p.as_mut_ptr(), p.len(), + vector::as_ptr(src), T::stride(src), T::len(src)) }; result_handler!(ret, ()) } - /// This function stores the indices of the k largest elements of the array src, of size n and stride stride, in the array p. The indices are chosen so that - /// the corresponding data is in descending numerical order. k must be less than or equal to n. The data src is not modified by this operation. + /// This function stores the indices of the `p.len()` largest + /// elements of the vector `src` in the slice `p`. The indices + /// are chosen so that the corresponding data is in descending + /// numerical order. Panic if `p.len()` is larger than the size + /// of `src`. #[doc(alias = "gsl_sort_largest_index")] - pub fn sort_largest_index( - p: &mut [usize], - k: usize, - src: &[f64], - stride: usize, - ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_sort_largest_index(p.as_mut_ptr(), k, src.as_ptr(), stride, src.len() as _) + pub fn sort_largest_index(p: &mut [usize], src: &T) -> Result<(), Value> + where T: Vector + ?Sized { + let ret = unsafe { sys::gsl_sort_largest_index( + p.as_mut_ptr(), p.len(), + vector::as_ptr(src), T::stride(src), T::len(src)) }; result_handler!(ret, ()) } - /// This function stores the indices of the k smallest or largest elements of the vector v in the array p. k must be less than or equal to the length of - /// the vector v. + /// This function stores the indices of the `p.len()` smallest + /// elements of the vector `v` in the slice `p`. Panic if + /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest_index")] - pub fn sort_vector_smallest_index( - p: &mut [usize], - k: usize, - v: &VectorF64, - ) -> Result<(), Value> { - let ret = - unsafe { sys::gsl_sort_vector_smallest_index(p.as_mut_ptr(), k, v.unwrap_shared()) }; + #[deprecated(since="8.0", note="Please use `sort_smallest_index` instead")] + pub fn sort_vector_smallest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { + if p.len() > v.len() { + panic!("rgsl::sort::select::sort_vector_smallest_index: `p.len() > v.len()`"); + } + let ret = unsafe { sys::gsl_sort_vector_smallest_index( + p.as_mut_ptr(), p.len(), v.unwrap_shared()) }; result_handler!(ret, ()) } - /// This function stores the indices of the k smallest or largest elements of the vector v in the array p. k must be less than or equal to the length of - /// the vector v. + /// This function stores the indices of the `p.len()` largest + /// elements of the vector `v` in the slice `p`. Panic if + /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest_index")] - pub fn sort_vector_largest_index( - p: &mut [usize], - k: usize, - v: &VectorF64, - ) -> Result<(), Value> { - let ret = - unsafe { sys::gsl_sort_vector_largest_index(p.as_mut_ptr(), k, v.unwrap_shared()) }; + #[deprecated(since="8.0", note="Please use `sort_largest_index` instead")] + pub fn sort_vector_largest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { + if p.len() > v.len() { + panic!("rgsl::sort::select::sort_vector_largest_index: `p.len() > v.len()`"); + } + let ret = unsafe { sys::gsl_sort_vector_largest_index( + p.as_mut_ptr(), p.len(), v.unwrap_shared()) }; result_handler!(ret, ()) } } From a1364af739a352a7a1f8093bac12e0562346094e Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 22:42:05 +0200 Subject: [PATCH 08/28] Use the Vector trait in stats (and add documentation & functions) --- src/stats.rs | 391 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 275 insertions(+), 116 deletions(-) diff --git a/src/stats.rs b/src/stats.rs index ce8de6ee..86012dab 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -2,178 +2,333 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // -#[doc(alias = "gsl_stats_wtss")] -pub fn wtss(w: &[f64], wstride: usize, data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_wtss(w.as_ptr(), wstride, data.as_ptr(), stride, data.len() as _) } +use crate::vector::{self, Vector}; + +#[cfg(feature = "v2_5")] +use crate::vector::VectorMut; + +// FIXME: Many functions are missing. + +/// # Weighted Samples +/// +/// The functions described in this section allow the computation of +/// statistics for weighted samples. The functions accept a vector of +/// samples, xᵢ, with associated weights, wᵢ. Each sample xᵢ is +/// considered as having been drawn from a Gaussian distribution with +/// variance σᵢ². The sample weight wᵢ is defined as the reciprocal +/// of this variance, wᵢ = 1/σᵢ². Setting a weight to zero +/// corresponds to removing a sample from a dataset. + +/// Return the weighted mean of the dataset `data` using the set of +/// weights `w`. The weighted mean is defined as, +/// ̂μ = (∑ wᵢ xᵢ) / (∑ wᵢ). +/// +/// # Example +/// +/// ``` +/// use rgsl::{stats::wmean, vector::Vector}; +/// let m = wmean(&[1., 1.], &[1., 1.]); +/// assert_eq!(m, 1.); +/// ``` +#[doc(alias = "gsl_stats_wmean")] +pub fn wmean(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wmean: the size of w and data must be the same"); + } + unsafe { sys:: gsl_stats_wmean( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) + } + } -#[doc(alias = "gsl_stats_wtss_m")] -pub fn wtss_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, wmean: f64) -> f64 { - unsafe { - sys::gsl_stats_wtss_m( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - data.len() as _, - wmean, - ) +/// Returns the estimated variance of the weighted dataset `data` +/// using the set of weights `w`. The estimated variance of a +/// weighted dataset is calculated as, +/// ̂σ² = (∑ wᵢ) / ((∑ wᵢ)² - ∑ wᵢ²) · ∑ wᵢ (xᵢ - ̂μ)². +#[doc(alias = "gsl_stats_wvariance")] +pub fn wvariance(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wvariance: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wvariance( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) } } -#[doc(alias = "gsl_stats_wabsdev")] -pub fn wabsdev(w: &[f64], wstride: usize, data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_wabsdev(w.as_ptr(), wstride, data.as_ptr(), stride, data.len() as _) } +/// Returns the estimated variance of the weighted dataset `data` +/// using the given weighted mean `wmean`. +#[doc(alias = "gsl_stats_wvariance_m")] +pub fn wvariance_m(w: &T, data: &T, wmean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wvariance_m: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wvariance_m( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean) + } } -#[doc(alias = "gsl_stats_wskew")] -pub fn wskew(w: &[f64], wstride: usize, data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_wskew(w.as_ptr(), wstride, data.as_ptr(), stride, data.len() as _) } +/// Return the standard deviation is defined as the square root of the +/// variance. +#[doc(alias = "gsl_stats_wsd")] +pub fn wsd(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wsd: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wsd( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) + } } -#[doc(alias = "gsl_stats_wkurtosis")] -pub fn wkurtosis(w: &[f64], wstride: usize, data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_wkurtosis(w.as_ptr(), wstride, data.as_ptr(), stride, data.len() as _) } +/// Return the standard deviation is defined as the square root of the +/// variance using the given weighted mean `wmean`. +#[doc(alias = "gsl_stats_wsd_m")] +pub fn wsd_m(w: &T, data: &T, wmean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wsd_m: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wsd_m( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean) + } } -#[doc(alias = "gsl_stats_wvariance_m")] -pub fn wvariance_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, wmean: f64) -> f64 { - unsafe { - sys::gsl_stats_wvariance_m( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - data.len() as _, - wmean, - ) +/// Return an unbiased estimate of the variance of the weighted +/// dataset `data` when the population mean `mean` of the underlying +/// distribution is known a priori. In this case the estimator for +/// the variance replaces the sample mean ̂μ by the known population +/// mean μ: +/// σ² = ∑ wᵢ (xᵢ - μ)² / (∑ wᵢ).. +#[doc(alias = "gsl_stats_wvariance_with_fixed_mean")] +pub fn wvariance_with_fixed_mean(w: &T, data: &T, mean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wvariance_with_fixed_mean: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wvariance_with_fixed_mean( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + mean) + } +} + +/// Return the standard deviation which is defined as the square root +/// of the variance computed by [`wvariance_with_fixed_mean`]. +#[doc(alias = "gsl_stats_wsd_with_fixed_mean")] +pub fn wsd_with_fixed_mean(w: &T, data: &T, mean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wsd_with_fixed_mean: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wsd_with_fixed_mean( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + mean) + } +} + +/// Return the weighted total sum of squares (TSS) of data about the +/// weighted mean. TSS = ∑ wᵢ (xᵢ - wmean)² where the weighted mean +/// wmean is computed internally. +#[doc(alias = "gsl_stats_wtss")] +pub fn wtss(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wtss: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wtss( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) + } +} + +/// Return the weighted total sum of squares (TSS) of data about the +/// weighted mean. TSS = ∑ wᵢ (xᵢ - `wmean`)². +#[doc(alias = "gsl_stats_wtss_m")] +pub fn wtss_m(w: &T, data: &T, wmean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wtss_m: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wtss_m( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean) + } +} + +/// Return the weighted absolute deviation from the weighted mean of +/// data. The absolute deviation from the mean is defined as, +/// absdev = (∑ wᵢ |xᵢ - ̂μ|) / (∑ wᵢ) +#[doc(alias = "gsl_stats_wabsdev")] +pub fn wabsdev(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wabsdev: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wabsdev( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) } } +/// Return the absolute deviation of the weighted dataset data about +/// the given weighted mean `wmean`. #[doc(alias = "gsl_stats_wabsdev_m")] -pub fn wabsdev_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, wmean: f64) -> f64 { - unsafe { - sys::gsl_stats_wabsdev_m( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - data.len() as _, - wmean, - ) +pub fn wabsdev_m(w: &T, data: &T, wmean: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wabsdev_m: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wabsdev_m( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean) + } +} + +/// Return the weighted skewness of the dataset `data`. +/// skew = (∑ wᵢ ((xᵢ - ̂x) / ̂σ)³) / (∑ wᵢ) +#[doc(alias = "gsl_stats_wskew")] +pub fn wskew(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wskew: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wskew( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) } } +/// Return the weighted skewness of the dataset `data` using the given +/// values of the weighted mean and weighted standard deviation, +/// `wmean` and `wsd`. #[doc(alias = "gsl_stats_wskew_m_sd")] -pub fn wskew_m_sd( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - wmean: f64, - wsd: f64, -) -> f64 { - unsafe { - sys::gsl_stats_wskew_m_sd( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - data.len() as _, - wmean, - wsd, - ) +pub fn wskew_m_sd(w: &T, data: &T, wmean: f64, wsd: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wskew_m_sd: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wskew_m_sd( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean, wsd) + } +} + +/// Return the weighted kurtosis of the dataset `data`. +/// kurtosis = (∑ wᵢ ((xᵢ - ̂x) / ̂σ)⁴) / (∑ wᵢ) - 3 +#[doc(alias = "gsl_stats_wkurtosis")] +pub fn wkurtosis(w: &T, data: &T) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wkurtosis: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wkurtosis( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data)) } } +/// Return the weighted kurtosis of the dataset `data` using the given +/// values of the weighted mean and weighted standard deviation, +/// `wmean` and `wsd`. #[doc(alias = "gsl_stats_wkurtosis_m_sd")] -pub fn wkurtosis_m_sd( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - wmean: f64, - wsd: f64, -) -> f64 { - unsafe { - sys::gsl_stats_wkurtosis_m_sd( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - data.len() as _, - wmean, - wsd, - ) +pub fn wkurtosis_m_sd(w: &T, data: &T, wmean: f64, wsd: f64) -> f64 +where T: Vector + ?Sized { + if T::len(w) != T::len(data) { + panic!("rgsl::stats::wkurtosis_m_sd: the size of w and data must be the same"); + } + unsafe { sys::gsl_stats_wkurtosis_m_sd( + vector::as_ptr(w), T::stride(w), + vector::as_ptr(data), T::stride(data), T::len(data), + wmean, wsd) } } + + #[doc(alias = "gsl_stats_pvariance")] -pub fn pvariance(data1: &[f64], stride1: usize, data2: &[f64], stride2: usize) -> f64 { - unsafe { - sys::gsl_stats_pvariance( - data1.as_ptr(), - stride1, - data1.len() as _, - data2.as_ptr(), - stride2, - data2.len() as _, - ) +pub fn pvariance(data1: &T, data2: &T) -> f64 +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_pvariance( + vector::as_ptr(data1), T::stride(data1), T::len(data1), + vector::as_ptr(data2), T::stride(data2), T::len(data2)) } } #[doc(alias = "gsl_stats_ttest")] -pub fn ttest(data1: &[f64], stride1: usize, data2: &[f64], stride2: usize) -> f64 { - unsafe { - sys::gsl_stats_ttest( - data1.as_ptr(), - stride1, - data1.len() as _, - data2.as_ptr(), - stride2, - data2.len() as _, - ) +pub fn ttest(data1: &T, data2: &T) -> f64 +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_ttest( + vector::as_ptr(data1), T::stride(data1), T::len(data1), + vector::as_ptr(data2), T::stride(data2), T::len(data2)) } } #[doc(alias = "gsl_stats_max")] -pub fn max(data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_max(data.as_ptr(), stride, data.len() as _) } +pub fn max(data: &T) -> f64 +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_max( + vector::as_ptr(data), T::stride(data), T::len(data)) } } #[doc(alias = "gsl_stats_min")] -pub fn min(data: &[f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_min(data.as_ptr(), stride, data.len() as _) } +pub fn min(data: &T) -> f64 +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_min( + vector::as_ptr(data), T::stride(data), T::len(data)) } } /// Returns `(min, max)`. #[doc(alias = "gsl_stats_minmax")] -pub fn stats_minmax(data: &[f64], stride: usize) -> (f64, f64) { +pub fn stats_minmax(data: &T) -> (f64, f64) +where T: Vector + ?Sized { let mut min = 0.; let mut max = 0.; - unsafe { sys::gsl_stats_minmax(&mut min, &mut max, data.as_ptr(), stride, data.len() as _) } + unsafe { sys::gsl_stats_minmax( + &mut min, &mut max, + vector::as_ptr(data), T::stride(data), T::len(data)) } (min, max) } #[doc(alias = "gsl_stats_max_index")] -pub fn max_index(data: &[f64], stride: usize) -> usize { - unsafe { sys::gsl_stats_max_index(data.as_ptr(), stride, data.len() as _) } +pub fn max_index(data: &T) -> usize +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_max_index( + vector::as_ptr(data), T::stride(data), T::len(data)) } } #[doc(alias = "gsl_stats_min_index")] -pub fn min_index(data: &[f64], stride: usize) -> usize { - unsafe { sys::gsl_stats_min_index(data.as_ptr(), stride, data.len() as _) } +pub fn min_index(data: &T) -> usize +where T: Vector + ?Sized { + unsafe { sys::gsl_stats_min_index( + vector::as_ptr(data), T::stride(data), T::len(data)) } } /// Returns `(min, max)`. #[doc(alias = "gsl_stats_minmax_index")] -pub fn stats_minmax_index(data: &[f64], stride: usize) -> (usize, usize) { +pub fn stats_minmax_index(data: &T) -> (usize, usize) +where T: Vector + ?Sized { let mut min = 0; let mut max = 0; unsafe { - sys::gsl_stats_minmax_index(&mut min, &mut max, data.as_ptr(), stride, data.len() as _) + sys::gsl_stats_minmax_index( + &mut min, &mut max, + vector::as_ptr(data), T::stride(data), T::len(data)) } (min, max) } @@ -181,13 +336,17 @@ pub fn stats_minmax_index(data: &[f64], stride: usize) -> (usize, usize) { #[cfg(feature = "v2_5")] #[cfg_attr(feature = "dox", doc(cfg(feature = "v2_5")))] #[doc(alias = "gsl_stats_select")] -pub fn select(data: &mut [f64], stride: usize, k: usize) -> f64 { - unsafe { sys::gsl_stats_select(data.as_mut_ptr(), stride, data.len() as _, k) } +pub fn select(data: &mut T, k: usize) -> f64 +where T: VectorMut + ?Sized { + unsafe { sys::gsl_stats_select( + vector::as_mut_ptr(data), T::stride(data), T::len(data), k) } } #[cfg(feature = "v2_5")] #[cfg_attr(feature = "dox", doc(cfg(feature = "v2_5")))] #[doc(alias = "gsl_stats_median")] -pub fn median(data: &mut [f64], stride: usize) -> f64 { - unsafe { sys::gsl_stats_median(data.as_mut_ptr(), stride, data.len() as _) } +pub fn median(data: &mut T) -> f64 +where T: VectorMut + ?Sized { + unsafe { sys::gsl_stats_median( + vector::as_mut_ptr(data), T::stride(data), T::len(data)) } } From e3bcd9ea6219736c206c27a4c23ee92cd37c4610 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 22:49:26 +0200 Subject: [PATCH 09/28] Make clippy happy --- src/lib.rs | 2 +- src/sort.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 66d687ea..59a4f922 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -118,7 +118,7 @@ pub static ROOT4_DBL_MIN: f64 = 1.221_338_669_755_462_0e-77; pub static ROOT5_DBL_MIN: f64 = 2.947_602_296_969_176_3e-62; pub static ROOT6_DBL_MIN: f64 = 5.303_436_890_579_821_8e-52; -pub static DBL_MAX: f64 = std::f64::MAX; //1.7976931348623156e+308; +pub static DBL_MAX: f64 = f64::MAX; //1.7976931348623156e+308; pub static SQRT_DBL_MAX: f64 = 1.340_780_792_994_259_6e+154; pub static ROOT3_DBL_MAX: f64 = 5.643_803_094_122_289_7e+102; pub static ROOT4_DBL_MAX: f64 = 1.157_920_892_373_162_0e+77; diff --git a/src/sort.rs b/src/sort.rs index 87a65992..e1c06461 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -108,14 +108,14 @@ pub mod vectors { /// This function sorts the elements of the vector v into ascending numerical order. #[doc(alias = "gsl_sort_vector")] - #[deprecated(since="8.0", note="Please use `sort` instead")] + #[deprecated(since="8.0.0", note="Please use `sort` instead")] pub fn sort_vector(v: &mut VectorF64) { unsafe { sys::gsl_sort_vector(v.unwrap_unique()) } } /// This function sorts the elements of the vector v1 into ascending numerical order, while making the same rearrangement of the vector v2. #[doc(alias = "gsl_sort_vector2")] - #[deprecated(since="8.0", note="Please use `sort2` instead")] + #[deprecated(since="8.0.0", note="Please use `sort2` instead")] pub fn sort_vector2(v1: &mut VectorF64, v2: &mut VectorF64) { unsafe { sys::gsl_sort_vector2(v1.unwrap_unique(), v2.unwrap_unique()) } } @@ -196,7 +196,7 @@ pub mod select { /// vector `v` into the slice `dest`. Panic if `dest.len()` is /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest")] - #[deprecated(since="8.0", note="Please use `sort_smallest` instead")] + #[deprecated(since="8.0.0", note="Please use `sort_smallest` instead")] pub fn sort_vector_smallest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { if dest.len() > v.len() { panic!("rgsl::sort::select::sort_vector_smallest: `dest.len() > v.len()`"); @@ -210,7 +210,7 @@ pub mod select { /// vector `v` into the array dest. Panic if `dest.len()` is /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest")] - #[deprecated(since="8.0", note="Please use `sort_largest` instead")] + #[deprecated(since="8.0.0", note="Please use `sort_largest` instead")] pub fn sort_vector_largest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { if dest.len() > v.len() { panic!("rgsl::sort::select::sort_vector_largest: `dest.len() > v.len()`"); @@ -257,7 +257,7 @@ pub mod select { /// elements of the vector `v` in the slice `p`. Panic if /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest_index")] - #[deprecated(since="8.0", note="Please use `sort_smallest_index` instead")] + #[deprecated(since="8.0.0", note="Please use `sort_smallest_index` instead")] pub fn sort_vector_smallest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { if p.len() > v.len() { panic!("rgsl::sort::select::sort_vector_smallest_index: `p.len() > v.len()`"); @@ -271,7 +271,7 @@ pub mod select { /// elements of the vector `v` in the slice `p`. Panic if /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest_index")] - #[deprecated(since="8.0", note="Please use `sort_largest_index` instead")] + #[deprecated(since="8.0.0", note="Please use `sort_largest_index` instead")] pub fn sort_vector_largest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { if p.len() > v.len() { panic!("rgsl::sort::select::sort_vector_largest_index: `p.len() > v.len()`"); From 631b465afce643356b21ca52b8c8dc3348ac19d5 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 22:51:08 +0200 Subject: [PATCH 10/28] Run cargo fmt --- src/cblas.rs | 105 +++++++++---- src/sort.rs | 138 +++++++++++------- src/stats.rs | 348 +++++++++++++++++++++++++++++++------------- src/types/vector.rs | 52 +++++-- 4 files changed, 441 insertions(+), 202 deletions(-) diff --git a/src/cblas.rs b/src/cblas.rs index 7a335688..8781afc5 100644 --- a/src/cblas.rs +++ b/src/cblas.rs @@ -3,10 +3,7 @@ // pub mod level1 { - use crate::vector::{ - Vector, VectorMut, - as_mut_ptr, as_ptr, len, stride, check_equal_len, - }; + use crate::vector::{as_mut_ptr, as_ptr, check_equal_len, len, stride, Vector, VectorMut}; #[cfg(feature = "complex")] use num_complex::Complex; @@ -269,7 +266,10 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_sswap")] pub fn sswap(x: &mut T1, y: &mut T2) - where T1: VectorMut + ?Sized, T2: VectorMut + ?Sized { + where + T1: VectorMut + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_sswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } @@ -277,7 +277,10 @@ pub mod level1 { /// Copy the content of `x` into `y`. #[doc(alias = "cblas_scopy")] pub fn scopy(x: &T1, y: &mut T2) - where T1: Vector + ?Sized, T2: VectorMut + ?Sized { + where + T1: Vector + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_scopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } @@ -285,7 +288,10 @@ pub mod level1 { /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_saxpy")] pub fn saxpy(alpha: f32, x: &T1, y: &mut T2) - where T1: Vector + ?Sized, T2: VectorMut + ?Sized { + where + T1: Vector + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_saxpy( @@ -302,7 +308,10 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_dswap")] pub fn dswap(x: &mut T1, y: &mut T2) - where T1: VectorMut + ?Sized, T2: VectorMut + ?Sized { + where + T1: VectorMut + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dswap(len(x), as_mut_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } @@ -310,7 +319,10 @@ pub mod level1 { /// Copy the content of `x` into `y`. #[doc(alias = "cblas_dcopy")] pub fn dcopy(x: &T1, y: &mut T2) - where T1: Vector + ?Sized, T2: VectorMut + ?Sized { + where + T1: Vector + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_dcopy(len(x), as_ptr(x), stride(x), as_mut_ptr(y), stride(y)) } } @@ -318,7 +330,10 @@ pub mod level1 { /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_daxpy")] pub fn daxpy(alpha: f64, x: &T1, y: &mut T2) - where T1: Vector + ?Sized, T2: VectorMut + ?Sized { + where + T1: Vector + ?Sized, + T2: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_daxpy( @@ -336,8 +351,9 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_cswap")] pub fn cswap(x: &mut T1, y: &mut T2) - where T1: VectorMut> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: VectorMut> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -355,8 +371,9 @@ pub mod level1 { /// Copy the content of `x` into `y`. #[doc(alias = "cblas_ccopy")] pub fn ccopy(x: &T1, y: &mut T2) - where T1: Vector> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -374,8 +391,9 @@ pub mod level1 { /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_caxpy")] pub fn caxpy(alpha: &Complex, x: &T1, y: &mut T2) - where T1: Vector> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -394,8 +412,9 @@ pub mod level1 { /// Swap vectors `x` and `y`. #[doc(alias = "cblas_zswap")] pub fn zswap(x: &mut T1, y: &mut T2) - where T1: VectorMut> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: VectorMut> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -413,8 +432,9 @@ pub mod level1 { /// Copy the content of `x` into `y`. #[doc(alias = "cblas_zcopy")] pub fn zcopy(x: &T1, y: &mut T2) - where T1: Vector> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -432,8 +452,9 @@ pub mod level1 { /// `y` := `alpha` * `x` + `y`. #[doc(alias = "cblas_zaxpy")] pub fn zaxpy(alpha: &Complex, x: &T1, y: &mut T2) - where T1: Vector> + ?Sized, - T2: VectorMut> + ?Sized + where + T1: Vector> + ?Sized, + T2: VectorMut> + ?Sized, { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { @@ -559,7 +580,9 @@ pub mod level1 { /// for all indices i. #[doc(alias = "cblas_srot")] pub fn srot(x: &mut T, y: &mut T, c: f32, s: f32) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_srot( @@ -582,7 +605,9 @@ pub mod level1 { /// for all indices i. #[doc(alias = "cblas_srotm")] pub fn srotm(x: &mut T, y: &mut T, h: H) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -681,7 +706,9 @@ pub mod level1 { /// for all indices i. #[doc(alias = "cblas_drot")] pub fn drot(x: &mut T, y: &mut T, c: f64, s: f64) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); unsafe { sys::cblas_drot( @@ -704,7 +731,9 @@ pub mod level1 { /// for all indices i. #[doc(alias = "cblas_drotm")] pub fn drotm(x: &mut T, y: &mut T, h: H) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { check_equal_len(x, y).expect("Vectors `x` and `y` must have the same length"); let p = match h { H::Full { h11, h21, h12, h22 } => [-1.0, h11, h21, h12, h22], @@ -727,14 +756,18 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_sscal")] pub fn sscal(alpha: f32, x: &mut T) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { unsafe { sys::cblas_sscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_dscal")] pub fn dscal(alpha: f64, x: &mut T) - where T: VectorMut + ?Sized { + where + T: VectorMut + ?Sized, + { unsafe { sys::cblas_dscal(len(x), alpha, as_mut_ptr(x), stride(x)) } } @@ -742,7 +775,9 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_cscal")] pub fn cscal(alpha: &Complex, x: &mut T) - where T: VectorMut> + ?Sized { + where + T: VectorMut> + ?Sized, + { unsafe { sys::cblas_cscal( len(x), @@ -757,7 +792,9 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_zscal")] pub fn zscal(alpha: &Complex, x: &mut T) - where T: VectorMut> + ?Sized { + where + T: VectorMut> + ?Sized, + { unsafe { sys::cblas_zscal( len(x), @@ -772,7 +809,9 @@ pub mod level1 { /// Multiply each element of `x` by `alpha`. #[doc(alias = "cblas_csscal")] pub fn csscal(alpha: f32, x: &mut T) - where T: VectorMut> + ?Sized { + where + T: VectorMut> + ?Sized, + { unsafe { sys::cblas_csscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } @@ -780,7 +819,9 @@ pub mod level1 { /// Multiple each element of a matrix/vector by a constant. #[doc(alias = "cblas_zdscal")] pub fn zdscal(alpha: f64, x: &mut T) - where T: VectorMut> + ?Sized { + where + T: VectorMut> + ?Sized, + { unsafe { sys::cblas_zdscal(len(x), alpha, as_mut_ptr(x) as *mut _, stride(x)) } } } diff --git a/src/sort.rs b/src/sort.rs index e1c06461..dbda3a11 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -57,12 +57,10 @@ pub mod vectors { /// ``` #[doc(alias = "gsl_sort")] pub fn sort(data: &mut T) - where T: VectorMut + ?Sized { - unsafe { sys::gsl_sort( - vector::as_mut_ptr(data), - T::stride(data), - T::len(data)) - } + where + T: VectorMut + ?Sized, + { + unsafe { sys::gsl_sort(vector::as_mut_ptr(data), T::stride(data), T::len(data)) } } /// This function sorts the elements of the array `data1` into @@ -93,29 +91,33 @@ pub mod vectors { /// ``` #[doc(alias = "gsl_sort2")] pub fn sort2(data1: &mut T1, data2: &mut T2) - where T1: VectorMut + ?Sized, - T2: VectorMut + ?Sized + where + T1: VectorMut + ?Sized, + T2: VectorMut + ?Sized, { check_equal_len(data1, data2) .expect("rgsl::sort::sort2: the vectors must have the same length"); - unsafe { sys::gsl_sort2( - vector::as_mut_ptr(data1), - T1::stride(data1), - vector::as_mut_ptr(data2), - T2::stride(data2), - T1::len(data1)) } + unsafe { + sys::gsl_sort2( + vector::as_mut_ptr(data1), + T1::stride(data1), + vector::as_mut_ptr(data2), + T2::stride(data2), + T1::len(data1), + ) + } } /// This function sorts the elements of the vector v into ascending numerical order. #[doc(alias = "gsl_sort_vector")] - #[deprecated(since="8.0.0", note="Please use `sort` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort` instead")] pub fn sort_vector(v: &mut VectorF64) { unsafe { sys::gsl_sort_vector(v.unwrap_unique()) } } /// This function sorts the elements of the vector v1 into ascending numerical order, while making the same rearrangement of the vector v2. #[doc(alias = "gsl_sort_vector2")] - #[deprecated(since="8.0.0", note="Please use `sort2` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort2` instead")] pub fn sort_vector2(v1: &mut VectorF64, v2: &mut VectorF64) { unsafe { sys::gsl_sort_vector2(v1.unwrap_unique(), v2.unwrap_unique()) } } @@ -128,15 +130,19 @@ pub mod vectors { /// sorted in place. #[doc(alias = "gsl_sort_index")] pub fn sort_index(p: &mut [usize], data: &T) - where T: Vector + ?Sized { + where + T: Vector + ?Sized, + { if p.len() != T::len(data) { panic!("rgsl::sort::vectors::sort_index: `p` and `data` must have the same length"); } - unsafe { sys::gsl_sort_index( - p.as_mut_ptr(), - vector::as_ptr(data), - T::stride(data), - T::len(data)) + unsafe { + sys::gsl_sort_index( + p.as_mut_ptr(), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -165,13 +171,20 @@ pub mod select { /// `dest`. Panic if `dest.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_smallest")] pub fn sort_smallest(dest: &mut [f64], src: &T) -> Result<(), Value> - where T: Vector + ?Sized { + where + T: Vector + ?Sized, + { if dest.len() > T::len(src) { panic!("rgsl::sort::select::sort_smallest: `dest.len() > src.len()`"); } - let ret = unsafe { sys::gsl_sort_smallest( - dest.as_mut_ptr(), dest.len(), - vector::as_ptr(src), T::stride(src), T::len(src)) + let ret = unsafe { + sys::gsl_sort_smallest( + dest.as_mut_ptr(), + dest.len(), + vector::as_ptr(src), + T::stride(src), + T::len(src), + ) }; result_handler!(ret, ()) } @@ -181,13 +194,20 @@ pub mod select { /// `dest`. Panic if `dest.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_largest")] pub fn sort_largest(dest: &mut [f64], src: &T) -> Result<(), Value> - where T: Vector + ?Sized { + where + T: Vector + ?Sized, + { if dest.len() > T::len(src) { panic!("rgsl::sort::select::sort_largest: `dest.len() > src.len()`"); } - let ret = unsafe { sys::gsl_sort_largest( - dest.as_mut_ptr(), dest.len(), - vector::as_ptr(src), T::stride(src), T::len(src)) + let ret = unsafe { + sys::gsl_sort_largest( + dest.as_mut_ptr(), + dest.len(), + vector::as_ptr(src), + T::stride(src), + T::len(src), + ) }; result_handler!(ret, ()) } @@ -196,13 +216,14 @@ pub mod select { /// vector `v` into the slice `dest`. Panic if `dest.len()` is /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest")] - #[deprecated(since="8.0.0", note="Please use `sort_smallest` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort_smallest` instead")] pub fn sort_vector_smallest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { if dest.len() > v.len() { panic!("rgsl::sort::select::sort_vector_smallest: `dest.len() > v.len()`"); } - let ret = unsafe { sys::gsl_sort_vector_smallest( - dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) }; + let ret = unsafe { + sys::gsl_sort_vector_smallest(dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) + }; result_handler!(ret, ()) } @@ -210,13 +231,14 @@ pub mod select { /// vector `v` into the array dest. Panic if `dest.len()` is /// larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest")] - #[deprecated(since="8.0.0", note="Please use `sort_largest` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort_largest` instead")] pub fn sort_vector_largest(dest: &mut [f64], v: &VectorF64) -> Result<(), Value> { if dest.len() > v.len() { panic!("rgsl::sort::select::sort_vector_largest: `dest.len() > v.len()`"); } - let ret = unsafe { sys::gsl_sort_vector_largest( - dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) }; + let ret = unsafe { + sys::gsl_sort_vector_largest(dest.as_mut_ptr(), dest.len(), v.unwrap_shared()) + }; result_handler!(ret, ()) } @@ -227,13 +249,20 @@ pub mod select { /// of `src`. #[doc(alias = "gsl_sort_smallest_index")] pub fn sort_smallest_index(p: &mut [usize], src: &T) -> Result<(), Value> - where T: Vector + ?Sized { + where + T: Vector + ?Sized, + { if p.len() > T::len(src) { panic!("rgsl::sort::select::sort_smallest_index: `p.len() > src.len()`"); } - let ret = unsafe { sys::gsl_sort_smallest_index( - p.as_mut_ptr(), p.len(), - vector::as_ptr(src), T::stride(src), T::len(src)) + let ret = unsafe { + sys::gsl_sort_smallest_index( + p.as_mut_ptr(), + p.len(), + vector::as_ptr(src), + T::stride(src), + T::len(src), + ) }; result_handler!(ret, ()) } @@ -245,10 +274,17 @@ pub mod select { /// of `src`. #[doc(alias = "gsl_sort_largest_index")] pub fn sort_largest_index(p: &mut [usize], src: &T) -> Result<(), Value> - where T: Vector + ?Sized { - let ret = unsafe { sys::gsl_sort_largest_index( - p.as_mut_ptr(), p.len(), - vector::as_ptr(src), T::stride(src), T::len(src)) + where + T: Vector + ?Sized, + { + let ret = unsafe { + sys::gsl_sort_largest_index( + p.as_mut_ptr(), + p.len(), + vector::as_ptr(src), + T::stride(src), + T::len(src), + ) }; result_handler!(ret, ()) } @@ -257,13 +293,14 @@ pub mod select { /// elements of the vector `v` in the slice `p`. Panic if /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_smallest_index")] - #[deprecated(since="8.0.0", note="Please use `sort_smallest_index` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort_smallest_index` instead")] pub fn sort_vector_smallest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { if p.len() > v.len() { panic!("rgsl::sort::select::sort_vector_smallest_index: `p.len() > v.len()`"); } - let ret = unsafe { sys::gsl_sort_vector_smallest_index( - p.as_mut_ptr(), p.len(), v.unwrap_shared()) }; + let ret = unsafe { + sys::gsl_sort_vector_smallest_index(p.as_mut_ptr(), p.len(), v.unwrap_shared()) + }; result_handler!(ret, ()) } @@ -271,13 +308,14 @@ pub mod select { /// elements of the vector `v` in the slice `p`. Panic if /// `p.len()` is larger than the size of `src`. #[doc(alias = "gsl_sort_vector_largest_index")] - #[deprecated(since="8.0.0", note="Please use `sort_largest_index` instead")] + #[deprecated(since = "8.0.0", note = "Please use `sort_largest_index` instead")] pub fn sort_vector_largest_index(p: &mut [usize], v: &VectorF64) -> Result<(), Value> { if p.len() > v.len() { panic!("rgsl::sort::select::sort_vector_largest_index: `p.len() > v.len()`"); } - let ret = unsafe { sys::gsl_sort_vector_largest_index( - p.as_mut_ptr(), p.len(), v.unwrap_shared()) }; + let ret = unsafe { + sys::gsl_sort_vector_largest_index(p.as_mut_ptr(), p.len(), v.unwrap_shared()) + }; result_handler!(ret, ()) } } diff --git a/src/stats.rs b/src/stats.rs index 86012dab..5a254db0 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -32,15 +32,21 @@ use crate::vector::VectorMut; /// ``` #[doc(alias = "gsl_stats_wmean")] pub fn wmean(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wmean: the size of w and data must be the same"); } - unsafe { sys:: gsl_stats_wmean( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wmean( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } - } /// Returns the estimated variance of the weighted dataset `data` @@ -49,13 +55,20 @@ where T: Vector + ?Sized { /// ̂σ² = (∑ wᵢ) / ((∑ wᵢ)² - ∑ wᵢ²) · ∑ wᵢ (xᵢ - ̂μ)². #[doc(alias = "gsl_stats_wvariance")] pub fn wvariance(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wvariance: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wvariance( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wvariance( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -63,14 +76,21 @@ where T: Vector + ?Sized { /// using the given weighted mean `wmean`. #[doc(alias = "gsl_stats_wvariance_m")] pub fn wvariance_m(w: &T, data: &T, wmean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wvariance_m: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wvariance_m( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean) + unsafe { + sys::gsl_stats_wvariance_m( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + ) } } @@ -78,13 +98,20 @@ where T: Vector + ?Sized { /// variance. #[doc(alias = "gsl_stats_wsd")] pub fn wsd(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wsd: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wsd( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wsd( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -92,14 +119,21 @@ where T: Vector + ?Sized { /// variance using the given weighted mean `wmean`. #[doc(alias = "gsl_stats_wsd_m")] pub fn wsd_m(w: &T, data: &T, wmean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wsd_m: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wsd_m( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean) + unsafe { + sys::gsl_stats_wsd_m( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + ) } } @@ -111,14 +145,21 @@ where T: Vector + ?Sized { /// σ² = ∑ wᵢ (xᵢ - μ)² / (∑ wᵢ).. #[doc(alias = "gsl_stats_wvariance_with_fixed_mean")] pub fn wvariance_with_fixed_mean(w: &T, data: &T, mean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wvariance_with_fixed_mean: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wvariance_with_fixed_mean( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - mean) + unsafe { + sys::gsl_stats_wvariance_with_fixed_mean( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + mean, + ) } } @@ -126,14 +167,21 @@ where T: Vector + ?Sized { /// of the variance computed by [`wvariance_with_fixed_mean`]. #[doc(alias = "gsl_stats_wsd_with_fixed_mean")] pub fn wsd_with_fixed_mean(w: &T, data: &T, mean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wsd_with_fixed_mean: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wsd_with_fixed_mean( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - mean) + unsafe { + sys::gsl_stats_wsd_with_fixed_mean( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + mean, + ) } } @@ -142,13 +190,20 @@ where T: Vector + ?Sized { /// wmean is computed internally. #[doc(alias = "gsl_stats_wtss")] pub fn wtss(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wtss: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wtss( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wtss( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -156,14 +211,21 @@ where T: Vector + ?Sized { /// weighted mean. TSS = ∑ wᵢ (xᵢ - `wmean`)². #[doc(alias = "gsl_stats_wtss_m")] pub fn wtss_m(w: &T, data: &T, wmean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wtss_m: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wtss_m( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean) + unsafe { + sys::gsl_stats_wtss_m( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + ) } } @@ -172,13 +234,20 @@ where T: Vector + ?Sized { /// absdev = (∑ wᵢ |xᵢ - ̂μ|) / (∑ wᵢ) #[doc(alias = "gsl_stats_wabsdev")] pub fn wabsdev(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wabsdev: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wabsdev( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wabsdev( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -186,14 +255,21 @@ where T: Vector + ?Sized { /// the given weighted mean `wmean`. #[doc(alias = "gsl_stats_wabsdev_m")] pub fn wabsdev_m(w: &T, data: &T, wmean: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wabsdev_m: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wabsdev_m( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean) + unsafe { + sys::gsl_stats_wabsdev_m( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + ) } } @@ -201,13 +277,20 @@ where T: Vector + ?Sized { /// skew = (∑ wᵢ ((xᵢ - ̂x) / ̂σ)³) / (∑ wᵢ) #[doc(alias = "gsl_stats_wskew")] pub fn wskew(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wskew: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wskew( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wskew( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -216,14 +299,22 @@ where T: Vector + ?Sized { /// `wmean` and `wsd`. #[doc(alias = "gsl_stats_wskew_m_sd")] pub fn wskew_m_sd(w: &T, data: &T, wmean: f64, wsd: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wskew_m_sd: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wskew_m_sd( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean, wsd) + unsafe { + sys::gsl_stats_wskew_m_sd( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + wsd, + ) } } @@ -231,13 +322,20 @@ where T: Vector + ?Sized { /// kurtosis = (∑ wᵢ ((xᵢ - ̂x) / ̂σ)⁴) / (∑ wᵢ) - 3 #[doc(alias = "gsl_stats_wkurtosis")] pub fn wkurtosis(w: &T, data: &T) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wkurtosis: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wkurtosis( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data)) + unsafe { + sys::gsl_stats_wkurtosis( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } } @@ -246,89 +344,129 @@ where T: Vector + ?Sized { /// `wmean` and `wsd`. #[doc(alias = "gsl_stats_wkurtosis_m_sd")] pub fn wkurtosis_m_sd(w: &T, data: &T, wmean: f64, wsd: f64) -> f64 -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ if T::len(w) != T::len(data) { panic!("rgsl::stats::wkurtosis_m_sd: the size of w and data must be the same"); } - unsafe { sys::gsl_stats_wkurtosis_m_sd( - vector::as_ptr(w), T::stride(w), - vector::as_ptr(data), T::stride(data), T::len(data), - wmean, wsd) + unsafe { + sys::gsl_stats_wkurtosis_m_sd( + vector::as_ptr(w), + T::stride(w), + vector::as_ptr(data), + T::stride(data), + T::len(data), + wmean, + wsd, + ) } } - - #[doc(alias = "gsl_stats_pvariance")] pub fn pvariance(data1: &T, data2: &T) -> f64 -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_pvariance( - vector::as_ptr(data1), T::stride(data1), T::len(data1), - vector::as_ptr(data2), T::stride(data2), T::len(data2)) +where + T: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_pvariance( + vector::as_ptr(data1), + T::stride(data1), + T::len(data1), + vector::as_ptr(data2), + T::stride(data2), + T::len(data2), + ) } } #[doc(alias = "gsl_stats_ttest")] pub fn ttest(data1: &T, data2: &T) -> f64 -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_ttest( - vector::as_ptr(data1), T::stride(data1), T::len(data1), - vector::as_ptr(data2), T::stride(data2), T::len(data2)) +where + T: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_ttest( + vector::as_ptr(data1), + T::stride(data1), + T::len(data1), + vector::as_ptr(data2), + T::stride(data2), + T::len(data2), + ) } } #[doc(alias = "gsl_stats_max")] pub fn max(data: &T) -> f64 -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_max( - vector::as_ptr(data), T::stride(data), T::len(data)) } +where + T: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_max(vector::as_ptr(data), T::stride(data), T::len(data)) } } #[doc(alias = "gsl_stats_min")] pub fn min(data: &T) -> f64 -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_min( - vector::as_ptr(data), T::stride(data), T::len(data)) } +where + T: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_min(vector::as_ptr(data), T::stride(data), T::len(data)) } } /// Returns `(min, max)`. #[doc(alias = "gsl_stats_minmax")] pub fn stats_minmax(data: &T) -> (f64, f64) -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ let mut min = 0.; let mut max = 0.; - unsafe { sys::gsl_stats_minmax( - &mut min, &mut max, - vector::as_ptr(data), T::stride(data), T::len(data)) } + unsafe { + sys::gsl_stats_minmax( + &mut min, + &mut max, + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) + } (min, max) } #[doc(alias = "gsl_stats_max_index")] pub fn max_index(data: &T) -> usize -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_max_index( - vector::as_ptr(data), T::stride(data), T::len(data)) } +where + T: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_max_index(vector::as_ptr(data), T::stride(data), T::len(data)) } } #[doc(alias = "gsl_stats_min_index")] pub fn min_index(data: &T) -> usize -where T: Vector + ?Sized { - unsafe { sys::gsl_stats_min_index( - vector::as_ptr(data), T::stride(data), T::len(data)) } +where + T: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_min_index(vector::as_ptr(data), T::stride(data), T::len(data)) } } /// Returns `(min, max)`. #[doc(alias = "gsl_stats_minmax_index")] pub fn stats_minmax_index(data: &T) -> (usize, usize) -where T: Vector + ?Sized { +where + T: Vector + ?Sized, +{ let mut min = 0; let mut max = 0; unsafe { sys::gsl_stats_minmax_index( - &mut min, &mut max, - vector::as_ptr(data), T::stride(data), T::len(data)) + &mut min, + &mut max, + vector::as_ptr(data), + T::stride(data), + T::len(data), + ) } (min, max) } @@ -337,16 +475,18 @@ where T: Vector + ?Sized { #[cfg_attr(feature = "dox", doc(cfg(feature = "v2_5")))] #[doc(alias = "gsl_stats_select")] pub fn select(data: &mut T, k: usize) -> f64 -where T: VectorMut + ?Sized { - unsafe { sys::gsl_stats_select( - vector::as_mut_ptr(data), T::stride(data), T::len(data), k) } +where + T: VectorMut + ?Sized, +{ + unsafe { sys::gsl_stats_select(vector::as_mut_ptr(data), T::stride(data), T::len(data), k) } } #[cfg(feature = "v2_5")] #[cfg_attr(feature = "dox", doc(cfg(feature = "v2_5")))] #[doc(alias = "gsl_stats_median")] pub fn median(data: &mut T) -> f64 -where T: VectorMut + ?Sized { - unsafe { sys::gsl_stats_median( - vector::as_mut_ptr(data), T::stride(data), T::len(data)) } +where + T: VectorMut + ?Sized, +{ + unsafe { sys::gsl_stats_median(vector::as_mut_ptr(data), T::stride(data), T::len(data)) } } diff --git a/src/types/vector.rs b/src/types/vector.rs index 4c410d56..d0712fa7 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -75,7 +75,11 @@ pub trait Vector { let slice = Self::as_slice(self); // FIXME: use std::slice::SliceIndex methods when stable. if r.end == 0 { - return Some(Slice { vec: &slice[0..0], len: 0, stride: 1 }) + return Some(Slice { + vec: &slice[0..0], + len: 0, + stride: 1, + }); } let end = (r.end - 1) * stride + 1; if r.start > r.end || end > slice.len() { @@ -83,9 +87,9 @@ pub trait Vector { } else { let start = r.start * stride; Some(Slice { - vec: &slice[start .. end], + vec: &slice[start..end], len: r.end - r.start, - stride + stride, }) } } @@ -97,11 +101,15 @@ pub trait VectorMut: Vector { /// Same as [`Vector::slice`] but mutable. fn slice_mut(&mut self, r: Range) -> Option> { - let stride = Self::stride(self); + let stride = Self::stride(self); let slice = Self::as_mut_slice(self); // FIXME: use std::slice::SliceIndex methods when stable. if r.end == 0 { - return Some(SliceMut { vec: &mut slice[0..0], len: 0, stride: 1 }) + return Some(SliceMut { + vec: &mut slice[0..0], + len: 0, + stride: 1, + }); } let end = (r.end - 1) * stride + 1; if r.start > r.end || end > slice.len() { @@ -109,9 +117,9 @@ pub trait VectorMut: Vector { } else { let start = r.start * stride; Some(SliceMut { - vec: &mut slice[start .. end], + vec: &mut slice[start..end], len: r.end - r.start, - stride + stride, }) } } @@ -131,25 +139,39 @@ pub struct SliceMut<'a, F> { impl<'a, F> Vector for Slice<'a, F> { #[inline] - fn len(x: &Self) -> usize { x.len } + fn len(x: &Self) -> usize { + x.len + } #[inline] - fn stride(x: &Self) -> usize { x.stride } + fn stride(x: &Self) -> usize { + x.stride + } #[inline] - fn as_slice(x: &Self) -> &[F] { x.vec } + fn as_slice(x: &Self) -> &[F] { + x.vec + } } impl<'a, F> Vector for SliceMut<'a, F> { #[inline] - fn len(x: &Self) -> usize { x.len } + fn len(x: &Self) -> usize { + x.len + } #[inline] - fn stride(x: &Self) -> usize { x.stride } + fn stride(x: &Self) -> usize { + x.stride + } #[inline] - fn as_slice(x: &Self) -> &[F] { x.vec } + fn as_slice(x: &Self) -> &[F] { + x.vec + } } impl<'a, F> VectorMut for SliceMut<'a, F> { #[inline] - fn as_mut_slice(x: &mut Self) -> &mut [F] { x.vec } + fn as_mut_slice(x: &mut Self) -> &mut [F] { + x.vec + } } /// Return the length of `x` as a `i32` value (to use in CBLAS calls). @@ -186,7 +208,6 @@ where Ok(()) } - macro_rules! gsl_vec { ($rust_name:ident, $name:ident, $rust_ty:ident) => ( paste! { @@ -780,4 +801,3 @@ impl_AsRef!(f64); impl_AsRef!(Complex); #[cfg(feature = "complex")] impl_AsRef!(Complex); - From c49db2cfbbfd9fc9e296547a108203c89c267376 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 15 Apr 2024 23:04:56 +0200 Subject: [PATCH 11/28] Fix some example using gls::stats --- examples/bspline.rs | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/bspline.rs b/examples/bspline.rs index b4924153..4cf152d6 100644 --- a/examples/bspline.rs +++ b/examples/bspline.rs @@ -64,12 +64,7 @@ fn main() { let chisq = mw.wlinear(&mat_x, &w, &y, &mut c, &mut cov).unwrap(); let dof = N - NCOEFFS; - let tss = stats::wtss( - w.as_slice().expect("as_slice failed"), - 1, - y.as_slice().expect("as_slice failed"), - 1, - ); + let tss = stats::wtss(&w, &y); let rsq = 1. - chisq / tss; eprintln!("chisq/dof = {}, rsq = {}", chisq / dof as f64, rsq); From 22700770e5e450e3be7f6fa57140e94b699312d5 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 23 Dec 2024 13:16:28 +0100 Subject: [PATCH 12/28] Fix documentation warnings --- src/enums.rs | 17 +++++++----- src/types/basis_spline.rs | 10 ++++--- src/types/chebyshev.rs | 5 ++-- src/types/discrete_hankel.rs | 20 ++++++++------ src/types/histograms.rs | 25 ++++++++++------- src/types/integration.rs | 8 +++--- src/types/mod.rs | 2 ++ src/types/monte_carlo.rs | 52 +++++++++++++++++++++--------------- src/types/rng.rs | 7 +++-- src/types/vector.rs | 2 +- 10 files changed, 92 insertions(+), 56 deletions(-) diff --git a/src/enums.rs b/src/enums.rs index 72bb7453..7cbe8b7d 100644 --- a/src/enums.rs +++ b/src/enums.rs @@ -355,17 +355,22 @@ impl From for IntegrationQawo { } } -/// Used by VegasMonteCarlo struct +/// Used by [`VegasParams`][crate::VegasParams]. /// -/// The possible choices are GSL_VEGAS_MODE_IMPORTANCE, GSL_VEGAS_MODE_ -/// STRATIFIED, GSL_VEGAS_MODE_IMPORTANCE_ONLY. This determines whether vegas -/// will use importance sampling or stratified sampling, or whether it can pick on -/// its own. In low dimensions vegas uses strict stratified sampling (more precisely, -/// stratified sampling is chosen if there are fewer than 2 bins per box). +/// This determines whether vegas will use importance sampling or +/// stratified sampling, or whether it can pick on its own. In low +/// dimensions vegas uses strict stratified sampling (more precisely, +/// stratified sampling is chosen if there are fewer than 2 bins per +/// box). #[derive(Clone, PartialEq, PartialOrd, Debug, Copy)] pub enum VegasMode { + /// Importance sampling: allocate more sample points where the + /// integrand is larger. Importance, + /// Exclusively use importance sampling without any stratification. ImportanceOnly, + /// Stratified sampling: divides the integration region into + /// sub-regions and sample each sub-region separately. Stratified, } diff --git a/src/types/basis_spline.rs b/src/types/basis_spline.rs index 4e91968b..ef4bab47 100644 --- a/src/types/basis_spline.rs +++ b/src/types/basis_spline.rs @@ -24,8 +24,9 @@ B_(i,k)(x) = [(x - t_i)/(t_(i+k-1) - t_i)] B_(i,k-1)(x) for i = 0, …, n-1. The common case of cubic B-splines is given by k = 4. The above recurrence relation can be evaluated in a numerically stable way by the de Boor algorithm. -If we define appropriate knots on an interval [a,b] then the B-spline basis functions form a -complete set on that interval. Therefore we can expand a smoothing function as +If we define appropriate knots on an interval \[a,b\] then the +B-spline basis functions form a complete set on that interval. +Therefore we can expand a smoothing function as f(x) = \sum_i c_i B_(i,k)(x) @@ -83,8 +84,9 @@ impl BSpLineWorkspace { result_handler!(ret, ()) } - /// This function assumes uniformly spaced breakpoints on [a,b] and constructs the corresponding - /// knot vector using the previously specified nbreak parameter. + /// This function assumes uniformly spaced breakpoints on \[a,b\] + /// and constructs the corresponding knot vector using the previously + /// specified nbreak parameter. /// The knots are stored in w->knots. #[doc(alias = "gsl_bspline_knots_uniform")] pub fn knots_uniform(&mut self, a: f64, b: f64) -> Result<(), Value> { diff --git a/src/types/chebyshev.rs b/src/types/chebyshev.rs index 1afdd762..788fde1d 100644 --- a/src/types/chebyshev.rs +++ b/src/types/chebyshev.rs @@ -15,8 +15,9 @@ For further information see Abramowitz & Stegun, Chapter 22. ## Definitions -The approximation is made over the range [a,b] using order+1 terms, including the coefficient -`c[0]`. The series is computed using the following convention, +The approximation is made over the range \[a,b\] using order+1 terms, +including the coefficient `c[0]`. The series is computed using the +following convention, f(x) = (c_0 / 2) + \sum_{n=1} c_n T_n(x) diff --git a/src/types/discrete_hankel.rs b/src/types/discrete_hankel.rs index 334bfa96..82783620 100644 --- a/src/types/discrete_hankel.rs +++ b/src/types/discrete_hankel.rs @@ -31,14 +31,18 @@ g_m = (2 / j_(\nu,M)^2) \sum_{k=1}^{M-1} f(j_(\nu,k)/j_(\nu,M)) (J_\nu(j_(\nu,m) j_(\nu,k) / j_(\nu,M)) / J_(\nu+1)(j_(\nu,k))^2). -It is this discrete expression which defines the discrete Hankel transform. The kernel in the -summation above defines the matrix of the \nu-Hankel transform of size M-1. The coefficients of this -matrix, being dependent on \nu and M, must be precomputed and stored; the gsl_dht object -encapsulates this data. The allocation function gsl_dht_alloc returns a gsl_dht object which must be -properly initialized with gsl_dht_init before it can be used to perform transforms on data sample -vectors, for fixed \nu and M, using the gsl_dht_apply function. The implementation allows a scaling -of the fundamental interval, for convenience, so that one can assume the function is defined on the -interval [0,X], rather than the unit interval. +It is this discrete expression which defines the discrete Hankel +transform. The kernel in the summation above defines the matrix of the +\nu-Hankel transform of size M-1. The coefficients of this matrix, +being dependent on \nu and M, must be precomputed and stored; the +gsl_dht object encapsulates this data. The allocation function +gsl_dht_alloc returns a gsl_dht object which must be properly +initialized with gsl_dht_init before it can be used to perform +transforms on data sample vectors, for fixed \nu and M, using the +gsl_dht_apply function. The implementation allows a scaling of the +fundamental interval, for convenience, so that one can assume the +function is defined on the interval \[0,X\], rather than the unit +interval. Notice that by assumption f(t) vanishes at the endpoints of the interval, consistent with the inversion formula and the sampling formula given above. Therefore, this transform corresponds to an diff --git a/src/types/histograms.rs b/src/types/histograms.rs index c104491b..b182e9ff 100644 --- a/src/types/histograms.rs +++ b/src/types/histograms.rs @@ -318,15 +318,22 @@ impl Histogram { } } -ffi_wrapper!(HistogramPdf, *mut sys::gsl_histogram_pdf, gsl_histogram_pdf_free, -"The probability distribution function for a histogram consists of a set of bins which measure the \ -probability of an event falling into a given range of a continuous variable x. A probability \ -distribution function is defined by the following struct, which actually stores the cumulative \ -probability distribution function. This is the natural quantity for generating samples via the \ -inverse transform method, because there is a one-to-one mapping between the cumulative probability \ -distribution and the range [0,1]. It can be shown that by taking a uniform random number in this \ -range and finding its corresponding coordinate in the cumulative probability distribution we obtain \ -samples with the desired probability distribution."); +ffi_wrapper!( + HistogramPdf, + *mut sys::gsl_histogram_pdf, + gsl_histogram_pdf_free, + "The probability distribution function for a histogram consists of \ +a set of bins which measure the probability of an event falling into a \ +given range of a continuous variable x. A probability distribution \ +function is defined by the following struct, which actually stores the \ +cumulative probability distribution function. This is the natural \ +quantity for generating samples via the inverse transform method, \ +because there is a one-to-one mapping between the cumulative probability \ +distribution and the range \\[0,1\\]. It can be shown that by taking \ +a uniform random number in this range and finding its corresponding \ +coordinate in the cumulative probability distribution we obtain \ +samples with the desired probability distribution." +); impl HistogramPdf { /// This function allocates memory for a probability distribution with n bins and returns a pointer to a newly initialized gsl_histogram_pdf diff --git a/src/types/integration.rs b/src/types/integration.rs index a533cb22..b2a51760 100644 --- a/src/types/integration.rs +++ b/src/types/integration.rs @@ -788,9 +788,11 @@ impl GLFixedTable { } } - /// For i in [0, …, t->n - 1], this function obtains the i-th Gauss-Legendre point xi and weight - /// wi on the interval [a,b]. The points and weights are ordered by increasing point value. A - /// function f may be integrated on [a,b] by summing wi * f(xi) over i. + /// For i in \[0, …, t->n - 1\], this function obtains the i-th + /// Gauss-Legendre point xi and weight wi on the interval \[a,b\]. + /// The points and weights are ordered by increasing point value. + /// A function f may be integrated on \[a,b\] by summing wi * + /// f(xi) over i. /// /// Returns `(xi, wi)` if it succeeded. #[doc(alias = "gsl_integration_glfixed_point")] diff --git a/src/types/mod.rs b/src/types/mod.rs index e49f68b5..d1153d14 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -2,6 +2,8 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // +//! GLS types (reexported into the root `rgsl`). + pub use self::basis_spline::BSpLineWorkspace; pub use self::chebyshev::ChebSeries; diff --git a/src/types/monte_carlo.rs b/src/types/monte_carlo.rs index 74ab3a4e..11b9c1ba 100644 --- a/src/types/monte_carlo.rs +++ b/src/types/monte_carlo.rs @@ -474,30 +474,40 @@ pub struct VegasParams<'a> { } impl<'a> VegasParams<'a> { - /// alpha: The parameter alpha controls the stiffness of the rebinning algorithm. It is typically - /// set between one and two. A value of zero prevents rebinning of the grid. The default - /// value is 1.5. + /// `alpha`: The parameter `alpha` controls the stiffness of the + /// rebinning algorithm. It is typically set between one and two. + /// A value of zero prevents rebinning of the grid. The default + /// value is `1.5`. /// - /// iterations: The number of iterations to perform for each call to the routine. The default value - /// is 5 iterations. + /// `iterations`: The number of iterations to perform for each + /// call to the routine. The default value is `5` iterations. /// - /// stage: Setting this determines the stage of the calculation. Normally, stage = 0 which begins - /// with a new uniform grid and empty weighted average. Calling vegas with stage = - /// 1 retains the grid from the previous run but discards the weighted average, so that - /// one can “tune” the grid using a relatively small number of points and then do a large - /// run with stage = 1 on the optimized grid. Setting stage = 2 keeps the grid and the - /// weighted average from the previous run, but may increase (or decrease) the number - /// of histogram bins in the grid depending on the number of calls available. Choosing - /// stage = 3 enters at the main loop, so that nothing is changed, and is equivalent to - /// performing additional iterations in a previous call. + /// `stage`: Setting this determines the stage of the + /// calculation. Normally, stage = 0 which begins with a new + /// uniform grid and empty weighted average. Calling vegas with + /// stage = 1 retains the grid from the previous run but discards + /// the weighted average, so that one can “tune” the grid using a + /// relatively small number of points and then do a large run with + /// stage = 1 on the optimized grid. Setting stage = 2 keeps the + /// grid and the weighted average from the previous run, but may + /// increase (or decrease) the number of histogram bins in the + /// grid depending on the number of calls available. Choosing + /// stage = 3 enters at the main loop, so that nothing is changed, + /// and is equivalent to performing additional iterations in a + /// previous call. /// - /// mode: The possible choices are GSL_VEGAS_MODE_IMPORTANCE, GSL_VEGAS_MODE_ - /// STRATIFIED, GSL_VEGAS_MODE_IMPORTANCE_ONLY. This determines whether vegas - /// will use importance sampling or stratified sampling, or whether it can pick on - /// its own. In low dimensions vegas uses strict stratified sampling (more precisely, - /// stratified sampling is chosen if there are fewer than 2 bins per box). + /// `mode`: The possible choices are + /// [`VegasMode::Importance`][crate::VegasMode::Importance], + /// [`VegasMode::Stratified`][crate::VegasMode::Stratified], and + /// [`VegasMode::ImportanceOnly`][crate::VegasMode::ImportanceOnly]. + /// This determines whether vegas will use importance sampling or + /// stratified sampling, or whether it can pick on its own. In low + /// dimensions vegas uses strict stratified sampling (more + /// precisely, stratified sampling is chosen if there are fewer + /// than 2 bins per box). /// - /// verbosity + stream: These parameters set the level of information printed by vegas. + /// `verbosity` and `stream`: These parameters set the level of + /// information printed by vegas. pub fn new( alpha: f64, iterations: usize, @@ -505,7 +515,7 @@ impl<'a> VegasParams<'a> { mode: crate::VegasMode, verbosity: VegasVerbosity, stream: Option<&'a mut crate::IOStream>, - ) -> Result { + ) -> Result, String> { if !verbosity.is_off() && stream.is_none() { return Err( "rust-GSL: need to provide an input stream for Vegas Monte Carlo \ diff --git a/src/types/rng.rs b/src/types/rng.rs index 8d8c771f..2f174101 100644 --- a/src/types/rng.rs +++ b/src/types/rng.rs @@ -110,8 +110,11 @@ impl Rng { unsafe { sys::gsl_rng_set(self.unwrap_unique(), s as _) } } - /// This function returns a random integer from the generator r. The minimum and maximum values depend on the algorithm used, but all integers in the range [min,max] are equally likely. - /// The values of min and max can be determined using the auxiliary functions gsl_rng_max (r) and gsl_rng_min (r). + /// This function returns a random integer from the generator r. + /// The minimum and maximum values depend on the algorithm used, + /// but all integers in the range \[min,max\] are equally likely. + /// The values of min and max can be determined using the + /// auxiliary functions gsl_rng_max (r) and gsl_rng_min (r). #[doc(alias = "gsl_rng_get")] pub fn get(&mut self) -> usize { unsafe { sys::gsl_rng_get(self.unwrap_shared()) as _ } diff --git a/src/types/vector.rs b/src/types/vector.rs index d0712fa7..9400f864 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -59,7 +59,7 @@ pub trait Vector { fn len(x: &Self) -> usize; /// The distance in the slice between two consecutive elements of - /// the vector in [`Vector::as_slice`] and [`Vector::as_mut_slice`]. + /// the vector in [`Vector::as_slice`] and [`VectorMut::as_mut_slice`]. fn stride(x: &Self) -> usize; /// Return a reference to the underlying slice. Note that the From 09b2689e251ea18d525b72362f432429b690cb7a Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Mon, 23 Dec 2024 13:20:11 +0100 Subject: [PATCH 13/28] CI: do not require a specific version of Ubuntu & use "sudo" --- .github/workflows/CI.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 216720e0..a50f6338 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -8,16 +8,14 @@ name: CI jobs: build-linux: runs-on: ubuntu-latest - container: - image: ubuntu:23.10 strategy: matrix: rust: - stable - nightly steps: - - run: apt-get update -y - - run: apt-get install -y libgsl0-dev curl build-essential python3 + - run: sudo apt-get update -y + - run: sudo apt-get install -y libgsl0-dev curl build-essential python3 - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: From e017f9ce390a067788837eaa5f05bee631127698 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Wed, 25 Dec 2024 12:01:44 +0100 Subject: [PATCH 14/28] Make the traits Vector and VectorMut unsafe An incorrect implementation of those traits may result in a out-of-bounds access in the C code, whence invalidating safety. --- src/types/vector.rs | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/types/vector.rs b/src/types/vector.rs index 9400f864..9148a9b6 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -51,7 +51,11 @@ use self::num_complex::Complex; /// /// Bring this trait into scope in order to add methods to specify /// strides to the types implementing `Vector`. -pub trait Vector { +/// +/// # Safety +/// One must make sore that `(len - 1) * stride` does not exceed the +/// length of the underlying slice. +pub unsafe trait Vector { /// Return the number of elements in the vector. /// /// This is an associated function rather than a method to avoid @@ -71,6 +75,10 @@ pub trait Vector { fn as_slice(x: &Self) -> &[F]; fn slice(&self, r: Range) -> Option> { + // The fields of `Slice` are not public, hence there is no way + // to implement this function outside this module. This + // guarantee that there is no out-of-bounds access as soon as + // the above methods are correctly implemented. let stride = Self::stride(self); let slice = Self::as_slice(self); // FIXME: use std::slice::SliceIndex methods when stable. @@ -95,7 +103,7 @@ pub trait Vector { } } -pub trait VectorMut: Vector { +pub unsafe trait VectorMut: Vector { /// Same as [`Vector::as_slice`] but mutable. fn as_mut_slice(x: &mut Self) -> &mut [F]; @@ -137,7 +145,7 @@ pub struct SliceMut<'a, F> { stride: usize, } -impl<'a, F> Vector for Slice<'a, F> { +unsafe impl<'a, F> Vector for Slice<'a, F> { #[inline] fn len(x: &Self) -> usize { x.len @@ -152,7 +160,7 @@ impl<'a, F> Vector for Slice<'a, F> { } } -impl<'a, F> Vector for SliceMut<'a, F> { +unsafe impl<'a, F> Vector for SliceMut<'a, F> { #[inline] fn len(x: &Self) -> usize { x.len @@ -167,7 +175,7 @@ impl<'a, F> Vector for SliceMut<'a, F> { } } -impl<'a, F> VectorMut for SliceMut<'a, F> { +unsafe impl<'a, F> VectorMut for SliceMut<'a, F> { #[inline] fn as_mut_slice(x: &mut Self) -> &mut [F] { x.vec @@ -726,7 +734,7 @@ impl<'a> [<$rust_name View>]<'a> { } } // end of impl block - impl Vector<$rust_ty> for $rust_name { + unsafe impl Vector<$rust_ty> for $rust_name { #[inline] fn len(x: &Self) -> usize { $rust_name::len(x) @@ -745,7 +753,7 @@ impl<'a> [<$rust_name View>]<'a> { $rust_name::as_slice(x).unwrap_or(&[]) } } - impl VectorMut<$rust_ty> for $rust_name { + unsafe impl VectorMut<$rust_ty> for $rust_name { #[inline] fn as_mut_slice(x: &mut Self) -> &mut [$rust_ty] { $rust_name::as_slice_mut(x).unwrap_or(&mut []) @@ -765,7 +773,7 @@ gsl_vec!(VectorU32, gsl_vector_uint, u32); macro_rules! impl_AsRef { ($ty: ty) => { - impl Vector<$ty> for T + unsafe impl Vector<$ty> for T where T: AsRef<[$ty]> + ?Sized, { @@ -783,7 +791,7 @@ macro_rules! impl_AsRef { } } - impl VectorMut<$ty> for T + unsafe impl VectorMut<$ty> for T where T: Vector<$ty> + AsMut<[$ty]> + ?Sized, { From 88756324530c9b9958d62fd138f6d9f6a7fb2b6c Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Wed, 25 Dec 2024 13:49:25 +0100 Subject: [PATCH 15/28] Use `Vector` for fast Fourier transform algorithms --- src/types/fast_fourier_transforms.rs | 53 +++++++++++++--------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/src/types/fast_fourier_transforms.rs b/src/types/fast_fourier_transforms.rs index 3dbfd9a6..fe61d853 100644 --- a/src/types/fast_fourier_transforms.rs +++ b/src/types/fast_fourier_transforms.rs @@ -3,7 +3,10 @@ // use crate::ffi::FFI; -use crate::Value; +use crate::{ + vector::VectorMut, + Value +}; use paste::paste; macro_rules! gsl_fft_wavetable { @@ -74,18 +77,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _forward)] - pub fn forward( + pub fn forward + ?Sized>( &mut self, - data: &mut [$ty], - stride: usize, - n: usize, + data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { let ret = unsafe { sys::[<$name $($extra)? _forward>]( - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -94,19 +95,17 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _transform)] - pub fn transform( + pub fn transform + ?Sized>( &mut self, - data: &mut [$ty], - stride: usize, - n: usize, + data: &mut V, wavetable: &$rust_name, sign: crate::FftDirection, ) -> Result<(), Value> { let ret = unsafe { sys::[<$name $($extra)? _transform>]( - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), sign.into(), @@ -116,18 +115,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _backward)] - pub fn backward( + pub fn backward + ?Sized>( &mut self, - data: &mut [$ty], - stride: usize, - n: usize, + data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { let ret = unsafe { sys::[<$name $($extra)? _backward>]( - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -136,18 +133,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _inverse)] - pub fn inverse( + pub fn inverse + ?Sized>( &mut self, - data: &mut [$ty], - stride: usize, - n: usize, + data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { let ret = unsafe { sys::[<$name $($extra)? _inverse>]( - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) From 718b1d27d3734a07a93145bb039504328ce0ec68 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Wed, 25 Dec 2024 13:58:00 +0100 Subject: [PATCH 16/28] Fix warnings for doc generation --- src/eigen.rs | 2 +- src/linear_algebra.rs | 4 ++-- src/minimizer.rs | 2 +- src/physical_constant.rs | 4 ++-- src/statistics.rs | 4 ++-- src/types/basis_spline.rs | 2 +- src/types/permutation.rs | 2 ++ src/types/rng.rs | 6 +++--- src/types/series_acceleration.rs | 2 +- src/types/vector.rs | 2 +- src/types/vector_complex.rs | 2 +- src/types/wavelet_transforms.rs | 2 +- 12 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/eigen.rs b/src/eigen.rs index 5d15267a..595de262 100644 --- a/src/eigen.rs +++ b/src/eigen.rs @@ -14,7 +14,7 @@ C. Moler, G. Stewart, “An Algorithm for Generalized Matrix Eigenvalue Problems Eigensystem routines for very large matrices can be found in the Fortran library LAPACK. The LAPACK library is described in, LAPACK Users’ Guide (Third Edition, 1999), Published by SIAM, ISBN 0-89871-447-8. -http://www.netlib.org/lapack + The LAPACK source code can be found at the website above along with an online copy of the users guide. !*/ diff --git a/src/linear_algebra.rs b/src/linear_algebra.rs index 58b0b867..2f8497c7 100644 --- a/src/linear_algebra.rs +++ b/src/linear_algebra.rs @@ -145,7 +145,7 @@ G. H. Golub, C. F. Van Loan, Matrix Computations (3rd Ed, 1996), Johns Hopkins U The LAPACK library is described in the following manual, LAPACK Users’ Guide (Third Edition, 1999), Published by SIAM, ISBN 0-89871-447-8. -http://www.netlib.org/lapack + The LAPACK source code can be found at the website above, along with an online copy of the users guide. @@ -158,7 +158,7 @@ J.C. Nash, “A one-sided transformation method for the singular value decomposi 1 (1975), p 74–76 J.C. Nash and S. Shlien “Simple algorithms for the partial singular value decomposition”, Computer Journal, Volume 30 (1987), p 268–275. James Demmel, Krešimir Veselić, “Jacobi’s Method is more accurate than QR”, Lapack Working Note 15 (LAWN-15), October 1989. Available from netlib, -http://www.netlib.org/lapack/ in the lawns or lawnspdf directories. + in the lawns or lawnspdf directories. !*/ use crate::enums; diff --git a/src/minimizer.rs b/src/minimizer.rs index b27958a4..be2da905 100644 --- a/src/minimizer.rs +++ b/src/minimizer.rs @@ -11,7 +11,7 @@ use crate::Value; /// |a - b| < epsabs + epsrel min(|a|,|b|) /// ``` /// -/// when the interval x = [a,b] does not include the origin. If the interval includes the origin then \min(|a|,|b|) is replaced by zero ( +/// when the interval x = \[a,b\] does not include the origin. If the interval includes the origin then \min(|a|,|b|) is replaced by zero ( /// which is the minimum value of |x| over the interval). This ensures that the relative error is accurately estimated for minima close to /// the origin. /// diff --git a/src/physical_constant.rs b/src/physical_constant.rs index c98c33ab..49727741 100644 --- a/src/physical_constant.rs +++ b/src/physical_constant.rs @@ -16,8 +16,8 @@ information on the values of physical constants is also available from the NIST P.J. Mohr, B.N. Taylor, D.B. Newell, “CODATA Recommended Values of the Fundamental Physical Constants: 2006”, Reviews of Modern Physics, 80(2), pp. 633–730 (2008). -http://www.physics.nist.gov/cuu/Constants/index.html -http://physics.nist.gov/Pubs/SP811/appenB9.html + + !*/ // Fundamental Constants diff --git a/src/statistics.rs b/src/statistics.rs index 4f5782bc..95521a33 100644 --- a/src/statistics.rs +++ b/src/statistics.rs @@ -41,7 +41,7 @@ For physicists the Particle Data Group provides useful reviews of Probability an Annual Review of Particle Physics. Review of Particle Properties R.M. Barnett et al., Physical Review D54, 1 (1996) -The Review of Particle Physics is available online at the website http://pdg.lbl.gov/. +The Review of Particle Physics is available online at the website . !*/ /// This function returns the arithmetic mean of data, a dataset of length n with stride stride. The @@ -573,7 +573,7 @@ pub fn median_from_sorted_data(data: &[f64], stride: usize, n: usize) -> f64 { /// /// where i is floor((n - 1)f) and \delta is (n-1)f - i. /// -/// Thus the minimum value of the array (data[0*stride]) is given by f equal to zero, the maximum +/// Thus the minimum value of the array (data\[0*stride\]) is given by f equal to zero, the maximum /// value (data[(n-1)*stride]) is given by f equal to one and the median value is given by f equal /// to 0.5. Since the algorithm for computing quantiles involves interpolation this function always /// returns a floating-point number, even for integer data types. diff --git a/src/types/basis_spline.rs b/src/types/basis_spline.rs index ef4bab47..e3e37aaa 100644 --- a/src/types/basis_spline.rs +++ b/src/types/basis_spline.rs @@ -45,7 +45,7 @@ Richard W. Johnson, Higher order B-spline collocation at the Greville abscissae. Mathematics. vol. 52, 2005, 63–75. A large collection of B-spline routines is available in the PPPACK library available at -http://www.netlib.org/pppack, which is also part of SLATEC. +, which is also part of SLATEC. !*/ use crate::ffi::FFI; diff --git a/src/types/permutation.rs b/src/types/permutation.rs index 38650d14..b20cf7d2 100644 --- a/src/types/permutation.rs +++ b/src/types/permutation.rs @@ -10,6 +10,8 @@ use crate::{MatrixComplexF32, MatrixComplexF64, MatrixF32, VectorF64}; use std::fmt::{self, Debug, Formatter}; use std::slice; +// FIXME: Permutations have the same representation as vectors. +// Do we want to wrap vectors? (The wrapping is to preserve invariants.) ffi_wrapper!(Permutation, *mut sys::gsl_permutation, gsl_permutation_free); /// ## Permutations in cyclic form diff --git a/src/types/rng.rs b/src/types/rng.rs index 2f174101..aa106a07 100644 --- a/src/types/rng.rs +++ b/src/types/rng.rs @@ -54,16 +54,16 @@ Donald E. Knuth, The Art of Computer Programming: Seminumerical Algorithms (Vol Further information is available in the review paper written by Pierre L’Ecuyer, P. L’Ecuyer, “Random Number Generation”, Chapter 4 of the Handbook on Simulation, Jerry Banks Ed., Wiley, 1998, 93–137. -http://www.iro.umontreal.ca/~lecuyer/papers.html + The source code for the DIEHARD random number generator tests is also available online, DIEHARD source code G. Marsaglia, -http://stat.fsu.edu/pub/diehard/ + A comprehensive set of random number generator tests is available from NIST, NIST Special Publication 800-22, “A Statistical Test Suite for the Validation of Random Number Generators and Pseudo Random Number Generators for Cryptographic Applications”. -http://csrc.nist.gov/rng/ + ## Acknowledgements diff --git a/src/types/series_acceleration.rs b/src/types/series_acceleration.rs index c634c91c..da9be683 100644 --- a/src/types/series_acceleration.rs +++ b/src/types/series_acceleration.rs @@ -46,7 +46,7 @@ D. Levin, Development of Non-Linear Transformations for Improving Convergence of A review paper on the Levin Transform is available online, -Herbert H. H. Homeier, Scalar Levin-Type Sequence Transformations, http://arxiv.org/abs/math/0005209. +Herbert H. H. Homeier, Scalar Levin-Type Sequence Transformations, . !*/ use crate::ffi::FFI; diff --git a/src/types/vector.rs b/src/types/vector.rs index 9148a9b6..4255c1b4 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -689,7 +689,7 @@ impl<'a> [<$rust_name View>]<'a> { /// n elements with a step-size of stride from one element to the next in the original /// array. Mathematically, the i-th element of the new vector v’ is given by, /// - /// v'(i) = base[i*stride] + /// v'(i) = base\[i*stride\] /// /// where the index i runs from 0 to n-1. /// diff --git a/src/types/vector_complex.rs b/src/types/vector_complex.rs index e71c2a5b..bf08d164 100644 --- a/src/types/vector_complex.rs +++ b/src/types/vector_complex.rs @@ -460,7 +460,7 @@ macro_rules! gsl_vec_complex { /// n elements with a step-size of stride from one element to the next in the original /// array. Mathematically, the i-th element of the new vector v’ is given by, /// - /// v'(i) = base[i*stride] + /// v'(i) = base\[i*stride\] /// /// where the index i runs from 0 to n-1. /// diff --git a/src/types/wavelet_transforms.rs b/src/types/wavelet_transforms.rs index 6e1877a8..2a765a02 100644 --- a/src/types/wavelet_transforms.rs +++ b/src/types/wavelet_transforms.rs @@ -46,7 +46,7 @@ The coefficients for the individual wavelet families implemented by the library I. Daubechies. Orthonormal Bases of Compactly Supported Wavelets. Communications on Pure and Applied Mathematics, 41 (1988) 909–996. A. Cohen, I. Daubechies, and J.-C. Feauveau. Biorthogonal Bases of Compactly Supported Wavelets. Communications on Pure and Applied Mathematics, 45 (1992) 485–560. -The PhysioNet archive of physiological datasets can be found online at http://www.physionet.org/ and is described in the following paper, +The PhysioNet archive of physiological datasets can be found online at and is described in the following paper, Goldberger et al. PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220 2000. From ba37b0d85c0a1efaa38f0f6a2e3ffe333b28a75f Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Wed, 25 Dec 2024 23:44:27 +0100 Subject: [PATCH 17/28] Use the standard Complex type In doing so, remove many std::mem::transmute, all type unsafe conversions being handled in the "complex" module. --- examples/eigen_nonsymm.rs | 4 +- src/blas.rs | 339 +-- src/linear_algebra.rs | 31 +- src/polynomials.rs | 52 +- src/types/complex.rs | 2748 +++++++++++------------- src/types/eigen_symmetric_workspace.rs | 38 +- src/types/matrix_complex.rs | 24 +- src/types/mod.rs | 2 +- src/types/vector_complex.rs | 27 +- 9 files changed, 1575 insertions(+), 1690 deletions(-) diff --git a/examples/eigen_nonsymm.rs b/examples/eigen_nonsymm.rs index 606857ca..149552f5 100644 --- a/examples/eigen_nonsymm.rs +++ b/examples/eigen_nonsymm.rs @@ -29,13 +29,13 @@ fn main() { for i in 0..4 { let eval_i = eval.get(i); evec.column(i, |evec_i| { - println!("eigenvalue = {} + {}", eval_i.real(), eval_i.imaginary()); + println!("eigenvalue = {eval_i}"); evec_i.expect("Failed to get get column").vector(|v| { let v = v.expect("Failed to get vector from column"); println!("eigenvector = "); for j in 0..4 { let z = v.get(j); - println!("{} + {}", z.real(), z.imaginary()); + println!("{z}"); } }); }); diff --git a/src/blas.rs b/src/blas.rs index d92e40a6..9e612f39 100644 --- a/src/blas.rs +++ b/src/blas.rs @@ -4,8 +4,9 @@ pub mod level1 { use crate::ffi::FFI; - use crate::types::complex::CFFI; - use crate::{types, Value}; + use crate::{VectorF32, VectorF64}; + use crate::{types, types::complex::{ToC, FromC}, Value}; + use num_complex::Complex; /// This function computes the sum \alpha + x^T y for the vectors x and y, returning the result /// in result. @@ -61,10 +62,10 @@ pub mod level1 { pub fn cdotu( x: &types::VectorComplexF32, y: &types::VectorComplexF32, - ) -> Result { - let mut dotu = types::ComplexF32::default().unwrap(); + ) -> Result, Value> { + let mut dotu = Complex::::default().unwrap(); let ret = unsafe { sys::gsl_blas_cdotu(x.unwrap_shared(), y.unwrap_shared(), &mut dotu) }; - result_handler!(ret, types::ComplexF32::wrap(dotu)) + result_handler!(ret, dotu.wrap()) } /// This function computes the complex scalar product x^T y for the vectors x and y, returning @@ -75,10 +76,10 @@ pub mod level1 { pub fn zdotu( x: &types::VectorComplexF64, y: &types::VectorComplexF64, - ) -> Result { - let mut dotu = types::ComplexF64::default().unwrap(); + ) -> Result, Value> { + let mut dotu = Complex::::default().unwrap(); let ret = unsafe { sys::gsl_blas_zdotu(x.unwrap_shared(), y.unwrap_shared(), &mut dotu) }; - result_handler!(ret, types::ComplexF64::wrap(dotu)) + result_handler!(ret, dotu.wrap()) } /// This function computes the complex conjugate scalar product x^H y for the vectors x and y, @@ -89,10 +90,10 @@ pub mod level1 { pub fn cdotc( x: &types::VectorComplexF32, y: &types::VectorComplexF32, - ) -> Result { - let mut dotc = types::ComplexF32::default().unwrap(); + ) -> Result, Value> { + let mut dotc = Complex::::default().unwrap(); let ret = unsafe { sys::gsl_blas_cdotc(x.unwrap_shared(), y.unwrap_shared(), &mut dotc) }; - result_handler!(ret, types::ComplexF32::wrap(dotc)) + result_handler!(ret, dotc.wrap()) } /// This function computes the complex conjugate scalar product x^H y for the vectors x and y, @@ -103,10 +104,10 @@ pub mod level1 { pub fn zdotc( x: &types::VectorComplexF64, y: &types::VectorComplexF64, - ) -> Result { - let mut dotc = types::ComplexF64::default().unwrap(); + ) -> Result, Value> { + let mut dotc = Complex::::default().unwrap(); let ret = unsafe { sys::gsl_blas_zdotc(x.unwrap_shared(), y.unwrap_shared(), &mut dotc) }; - result_handler!(ret, types::ComplexF64::wrap(dotc)) + result_handler!(ret, dotc.wrap()) } /// This function computes the Euclidean norm ||x||_2 = \sqrt {\sum x_i^2} of the vector x. @@ -278,13 +279,13 @@ pub mod level1 { /// This function computes the sum y = \alpha x + y for the vectors x and y. #[doc(alias = "gsl_blas_caxpy")] pub fn caxpy( - alpha: &types::ComplexF32, + alpha: &Complex, x: &types::VectorComplexF32, y: &mut types::VectorComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_caxpy( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_unique(), ) @@ -295,13 +296,13 @@ pub mod level1 { /// This function computes the sum y = \alpha x + y for the vectors x and y. #[doc(alias = "gsl_blas_zaxpy")] pub fn zaxpy( - alpha: &types::ComplexF64, + alpha: &Complex, x: &types::VectorComplexF64, y: &mut types::VectorComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zaxpy( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_unique(), ) @@ -323,14 +324,14 @@ pub mod level1 { /// This function rescales the vector x by the multiplicative factor alpha. #[doc(alias = "gsl_blas_cscal")] - pub fn cscal(alpha: &types::ComplexF32, x: &mut types::VectorComplexF32) { - unsafe { sys::gsl_blas_cscal(std::mem::transmute(*alpha), x.unwrap_unique()) } + pub fn cscal(alpha: &Complex, x: &mut types::VectorComplexF32) { + unsafe { sys::gsl_blas_cscal(alpha.unwrap(), x.unwrap_unique()) } } /// This function rescales the vector x by the multiplicative factor alpha. #[doc(alias = "gsl_blas_zscal")] - pub fn zscal(alpha: &types::ComplexF64, x: &mut types::VectorComplexF64) { - unsafe { sys::gsl_blas_zscal(std::mem::transmute(*alpha), x.unwrap_unique()) } + pub fn zscal(alpha: &Complex, x: &mut types::VectorComplexF64) { + unsafe { sys::gsl_blas_zscal(alpha.unwrap(), x.unwrap_unique()) } } /// This function rescales the vector x by the multiplicative factor alpha. @@ -349,44 +350,46 @@ pub mod level1 { /// /// ```text /// [ c s ] [ a ] = [ r ] - /// /// [ -s c ] [ b ] [ 0 ] /// ``` /// - /// The variables a and b are overwritten by the routine. + /// Return `(c, s, r)`. #[doc(alias = "gsl_blas_srotg")] - pub fn srotg(a: &mut [f32], b: &mut [f32], c: &mut [f32], d: &mut [f32]) -> Result<(), Value> { + pub fn srotg(mut a: f32, mut b: f32) -> Result<(f32, f32, f32), Value> { + let mut c = 0.; + let mut s = 0.; let ret = unsafe { sys::gsl_blas_srotg( - a.as_mut_ptr(), - b.as_mut_ptr(), - c.as_mut_ptr(), - d.as_mut_ptr(), + &mut a, + &mut b, + &mut c, + &mut s, ) }; - result_handler!(ret, ()) + result_handler!(ret, (c, s, a)) } /// This function computes a Givens rotation (c,s) which zeroes the vector (a,b), /// /// ```text /// [ c s ] [ a ] = [ r ] - /// /// [ -s c ] [ b ] [ 0 ] /// ``` /// - /// The variables a and b are overwritten by the routine. + /// Return `(c, s, r)`. #[doc(alias = "gsl_blas_drotg")] - pub fn drotg(a: &mut [f64], b: &mut [f64], c: &mut [f64], d: &mut [f64]) -> Result<(), Value> { + pub fn drotg(mut a: f64, mut b: f64) -> Result<(f64, f64, f64), Value> { + let mut c = 0.; + let mut s = 0.; let ret = unsafe { sys::gsl_blas_drotg( - a.as_mut_ptr(), - b.as_mut_ptr(), - c.as_mut_ptr(), - d.as_mut_ptr(), + &mut a, + &mut b, + &mut c, + &mut s, ) }; - result_handler!(ret, ()) + result_handler!(ret, (c, s, a)) } /// This function applies a Givens rotation (x', y') = (c x + s y, -s x + c y) to the vectors x, y. @@ -413,48 +416,50 @@ pub mod level1 { result_handler!(ret, ()) } - /// This function computes a modified Givens transformation. - /// The modified Givens transformation is defined in the original Level-1 BLAS specification, given in the references. + /// Return a modified Givens transformation. + /// The modified Givens transformation is defined in the original + /// [Level-1 BLAS specification](https://help.imsl.com/fortran/fnlmath/current/basic-linear-algebra-sub.htm#mch9_1817247609_srotmg). #[doc(alias = "gsl_blas_srotmg")] pub fn srotmg( - d1: &mut [f32], - d2: &mut [f32], - b1: &mut [f32], + mut d1: f32, + mut d2: f32, + mut b1: f32, b2: f32, - P: &mut [f32], - ) -> Result<(), Value> { + ) -> Result<[f32; 5], Value> { + let mut p = [f32::NAN; 5]; let ret = unsafe { sys::gsl_blas_srotmg( - d1.as_mut_ptr(), - d2.as_mut_ptr(), - b1.as_mut_ptr(), + &mut d1, + &mut d2, + &mut b1, b2, - P.as_mut_ptr(), + p.as_mut_ptr(), ) }; - result_handler!(ret, ()) + result_handler!(ret, p) } - /// This function computes a modified Givens transformation. - /// The modified Givens transformation is defined in the original Level-1 BLAS specification, given in the references. + /// Return a modified Givens transformation. + /// The modified Givens transformation is defined in the original + /// [Level-1 BLAS specification](https://help.imsl.com/fortran/fnlmath/current/basic-linear-algebra-sub.htm#mch9_1817247609_srotmg). #[doc(alias = "gsl_blas_drotmg")] pub fn drotmg( - d1: &mut [f64], - d2: &mut [f64], - b1: &mut [f64], + mut d1: f64, + mut d2: f64, + mut b1: f64, b2: f64, - P: &mut [f64], - ) -> Result<(), Value> { + ) -> Result<[f64; 5], Value> { + let mut p = [f64::NAN; 5]; let ret = unsafe { sys::gsl_blas_drotmg( - d1.as_mut_ptr(), - d2.as_mut_ptr(), - b1.as_mut_ptr(), + &mut d1, + &mut d2, + &mut b1, b2, - P.as_mut_ptr(), + p.as_mut_ptr(), ) }; - result_handler!(ret, ()) + result_handler!(ret, p) } /// This function applies a modified Givens transformation. @@ -462,10 +467,18 @@ pub mod level1 { pub fn srotm( x: &mut types::VectorF32, y: &mut types::VectorF32, - P: &mut [f32], + p: [f32; 5], ) -> Result<(), Value> { + let lenx = VectorF32::len(x); + let leny = VectorF32::len(y); + if lenx != leny { + panic!("rgsl::blas::srotm: len(x) = {lenx} != len(y) = {leny}") + } let ret = - unsafe { sys::gsl_blas_srotm(x.unwrap_unique(), y.unwrap_unique(), P.as_mut_ptr()) }; + unsafe { sys::gsl_blas_srotm( + x.unwrap_unique(), + y.unwrap_unique(), + p.as_ptr()) }; result_handler!(ret, ()) } @@ -474,17 +487,48 @@ pub mod level1 { pub fn drotm( x: &mut types::VectorF64, y: &mut types::VectorF64, - P: &mut [f64], + p: [f64; 5], ) -> Result<(), Value> { + let lenx = VectorF64::len(x); + let leny = VectorF64::len(y); + if lenx != leny { + panic!("rgsl::blas::drotm: len(x) = {lenx} != len(y) = {leny}") + } let ret = - unsafe { sys::gsl_blas_drotm(x.unwrap_unique(), y.unwrap_unique(), P.as_mut_ptr()) }; + unsafe { sys::gsl_blas_drotm( + x.unwrap_unique(), + y.unwrap_unique(), + p.as_ptr()) }; result_handler!(ret, ()) } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_srotg() { + let (c, s, r) = srotg(3., 4.).unwrap(); + assert_eq!(c, 0.6); + assert_eq!(s, 0.8); + assert_eq!(r, 5.); + } + + #[test] + fn test_drotg() { + let (c, s, r) = drotg(3., 4.).unwrap(); + assert_eq!(c, 0.6); + assert_eq!(s, 0.8); + assert_eq!(r, 5.); + } + } } pub mod level2 { + use crate::complex::ToC; use crate::ffi::FFI; use crate::{enums, types, Value}; + use num_complex::Complex; /// This function computes the matrix-vector product and sum y = \alpha op(A) x + \beta y, where op(A) = A, A^T, A^H for TransA = CblasNoTrans, CblasTrans, CblasConjTrans. #[doc(alias = "gsl_blas_sgemv")] @@ -536,19 +580,19 @@ pub mod level2 { #[doc(alias = "gsl_blas_cgemv")] pub fn cgemv( transA: enums::CblasTranspose, - alpha: &types::ComplexF32, + alpha: Complex, A: &types::MatrixComplexF32, x: &types::VectorComplexF32, - beta: &types::ComplexF32, + beta: &Complex, y: &mut types::VectorComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_cgemv( transA.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), x.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), y.unwrap_unique(), ) }; @@ -559,19 +603,19 @@ pub mod level2 { #[doc(alias = "gsl_blas_zgemv")] pub fn zgemv( transA: enums::CblasTranspose, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, x: &types::VectorComplexF64, - beta: &types::ComplexF64, + beta: &Complex, y: &mut types::VectorComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zgemv( transA.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), x.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), y.unwrap_unique(), ) }; @@ -818,19 +862,19 @@ pub mod level2 { #[doc(alias = "gsl_blas_chemv")] pub fn chemv( uplo: enums::CblasUplo, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, x: &types::VectorComplexF32, - beta: &types::ComplexF32, + beta: &Complex, y: &mut types::VectorComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_chemv( uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), x.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), y.unwrap_unique(), ) }; @@ -843,19 +887,19 @@ pub mod level2 { #[doc(alias = "gsl_blas_zhemv")] pub fn zhemv( uplo: enums::CblasUplo, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, x: &types::VectorComplexF64, - beta: &types::ComplexF64, + beta: &Complex, y: &mut types::VectorComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zhemv( uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), x.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), y.unwrap_unique(), ) }; @@ -903,14 +947,14 @@ pub mod level2 { /// This function computes the rank-1 update A = \alpha x y^T + A of the matrix A. #[doc(alias = "gsl_blas_cgeru")] pub fn cgeru( - alpha: &types::ComplexF32, + alpha: &Complex, x: &types::VectorComplexF32, y: &types::VectorComplexF32, A: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_cgeru( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -922,14 +966,14 @@ pub mod level2 { /// This function computes the rank-1 update A = \alpha x y^T + A of the matrix A. #[doc(alias = "gsl_blas_zgeru")] pub fn zgeru( - alpha: &types::ComplexF64, + alpha: &Complex, x: &types::VectorComplexF64, y: &types::VectorComplexF64, A: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zgeru( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -941,14 +985,14 @@ pub mod level2 { /// This function computes the conjugate rank-1 update A = \alpha x y^H + A of the matrix A. #[doc(alias = "gsl_blas_cgerc")] pub fn cgerc( - alpha: &types::ComplexF32, + alpha: &Complex, x: &types::VectorComplexF32, y: &types::VectorComplexF32, A: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_cgerc( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -960,14 +1004,14 @@ pub mod level2 { /// This function computes the conjugate rank-1 update A = \alpha x y^H + A of the matrix A. #[doc(alias = "gsl_blas_zgerc")] pub fn zgerc( - alpha: &types::ComplexF64, + alpha: &Complex, x: &types::VectorComplexF64, y: &types::VectorComplexF64, A: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zgerc( - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -1089,7 +1133,7 @@ pub mod level2 { #[doc(alias = "gsl_blas_cher2")] pub fn cher2( uplo: enums::CblasUplo, - alpha: &types::ComplexF32, + alpha: &Complex, x: &types::VectorComplexF32, y: &types::VectorComplexF32, A: &mut types::MatrixComplexF32, @@ -1097,7 +1141,7 @@ pub mod level2 { let ret = unsafe { sys::gsl_blas_cher2( uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -1113,7 +1157,7 @@ pub mod level2 { #[doc(alias = "gsl_blas_zher2")] pub fn zher2( uplo: enums::CblasUplo, - alpha: &types::ComplexF64, + alpha: &Complex, x: &types::VectorComplexF64, y: &types::VectorComplexF64, A: &mut types::MatrixComplexF64, @@ -1121,7 +1165,7 @@ pub mod level2 { let ret = unsafe { sys::gsl_blas_zher2( uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), x.unwrap_shared(), y.unwrap_shared(), A.unwrap_unique(), @@ -1133,7 +1177,8 @@ pub mod level2 { pub mod level3 { use crate::ffi::FFI; - use crate::{enums, types, Value}; + use crate::{complex::ToC, enums, types, Value}; + use num_complex::Complex; /// This function computes the matrix-matrix product and sum C = \alpha op(A) op(B) + \beta C where op(A) = A, A^T, A^H for TransA = CblasNoTrans, CblasTrans, CblasConjTrans and similarly for the parameter TransB. #[doc(alias = "gsl_blas_sgemm")] @@ -1190,20 +1235,20 @@ pub mod level3 { pub fn cgemm( transA: enums::CblasTranspose, transB: enums::CblasTranspose, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &types::MatrixComplexF32, - beta: &types::ComplexF32, + beta: &Complex, C: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_cgemm( transA.into(), transB.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1215,20 +1260,20 @@ pub mod level3 { pub fn zgemm( transA: enums::CblasTranspose, transB: enums::CblasTranspose, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &types::MatrixComplexF64, - beta: &types::ComplexF64, + beta: &Complex, C: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zgemm( transA.into(), transB.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1293,20 +1338,20 @@ pub mod level3 { pub fn csymm( side: enums::CblasSide, uplo: enums::CblasUplo, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &types::MatrixComplexF32, - beta: &types::ComplexF32, + beta: &Complex, C: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_csymm( side.into(), uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1319,20 +1364,20 @@ pub mod level3 { pub fn zsymm( side: enums::CblasSide, uplo: enums::CblasUplo, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &types::MatrixComplexF64, - beta: &types::ComplexF64, + beta: &Complex, C: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zsymm( side.into(), uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1346,20 +1391,20 @@ pub mod level3 { pub fn chemm( side: enums::CblasSide, uplo: enums::CblasUplo, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &types::MatrixComplexF32, - beta: &types::ComplexF32, + beta: &Complex, C: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_chemm( side.into(), uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1373,20 +1418,20 @@ pub mod level3 { pub fn zhemm( side: enums::CblasSide, uplo: enums::CblasUplo, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &types::MatrixComplexF64, - beta: &types::ComplexF64, + beta: &Complex, C: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zhemm( side.into(), uplo.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1459,7 +1504,7 @@ pub mod level3 { uplo: enums::CblasUplo, transA: enums::CblasTranspose, diag: enums::CblasDiag, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &mut types::MatrixComplexF32, ) -> Result<(), Value> { @@ -1469,7 +1514,7 @@ pub mod level3 { uplo.into(), transA.into(), diag.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_unique(), ) @@ -1487,7 +1532,7 @@ pub mod level3 { uplo: enums::CblasUplo, transA: enums::CblasTranspose, diag: enums::CblasDiag, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &mut types::MatrixComplexF64, ) -> Result<(), Value> { @@ -1497,7 +1542,7 @@ pub mod level3 { uplo.into(), transA.into(), diag.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_unique(), ) @@ -1571,7 +1616,7 @@ pub mod level3 { uplo: enums::CblasUplo, transA: enums::CblasTranspose, diag: enums::CblasDiag, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &mut types::MatrixComplexF32, ) -> Result<(), Value> { @@ -1581,7 +1626,7 @@ pub mod level3 { uplo.into(), transA.into(), diag.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_unique(), ) @@ -1599,7 +1644,7 @@ pub mod level3 { uplo: enums::CblasUplo, transA: enums::CblasTranspose, diag: enums::CblasDiag, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &mut types::MatrixComplexF64, ) -> Result<(), Value> { @@ -1609,7 +1654,7 @@ pub mod level3 { uplo.into(), transA.into(), diag.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_unique(), ) @@ -1674,18 +1719,18 @@ pub mod level3 { pub fn csyrk( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, - beta: &types::ComplexF32, + beta: &Complex, C: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_csyrk( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1699,18 +1744,18 @@ pub mod level3 { pub fn zsyrk( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, - beta: &types::ComplexF64, + beta: &Complex, C: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zsyrk( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1830,20 +1875,20 @@ pub mod level3 { pub fn csyr2k( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &types::MatrixComplexF32, - beta: &types::ComplexF32, + beta: &Complex, C: &mut types::MatrixComplexF32, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_csyr2k( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1857,20 +1902,20 @@ pub mod level3 { pub fn zsyr2k( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &types::MatrixComplexF64, - beta: &types::ComplexF64, + beta: &Complex, C: &mut types::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_blas_zsyr2k( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), - std::mem::transmute(*beta), + beta.unwrap(), C.unwrap_unique(), ) }; @@ -1885,7 +1930,7 @@ pub mod level3 { pub fn cher2k( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF32, + alpha: &Complex, A: &types::MatrixComplexF32, B: &types::MatrixComplexF32, beta: f32, @@ -1895,7 +1940,7 @@ pub mod level3 { sys::gsl_blas_cher2k( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), beta, @@ -1913,7 +1958,7 @@ pub mod level3 { pub fn zher2k( uplo: enums::CblasUplo, trans: enums::CblasTranspose, - alpha: &types::ComplexF64, + alpha: &Complex, A: &types::MatrixComplexF64, B: &types::MatrixComplexF64, beta: f64, @@ -1923,7 +1968,7 @@ pub mod level3 { sys::gsl_blas_zher2k( uplo.into(), trans.into(), - std::mem::transmute(*alpha), + alpha.unwrap(), A.unwrap_shared(), B.unwrap_shared(), beta, diff --git a/src/linear_algebra.rs b/src/linear_algebra.rs index 2f8497c7..2646f40e 100644 --- a/src/linear_algebra.rs +++ b/src/linear_algebra.rs @@ -161,11 +161,14 @@ James Demmel, Krešimir Veselić, “Jacobi’s Method is more accurate than QR in the lawns or lawnspdf directories. !*/ +use crate::complex::ToC; use crate::enums; use crate::ffi::FFI; -use crate::Value; - -use crate::types::complex::FFFI; +use crate::{ + complex::FromC, + Value, +}; +use num_complex::Complex; /// Factorise a general N x N matrix A into, /// @@ -381,7 +384,7 @@ pub fn LU_det(lu: &mut crate::MatrixF64, signum: i32) -> f64 { /// This function computes the determinant of a matrix A from its LU decomposition, LU. The determinant is computed as the product of the /// diagonal elements of U and the sign of the row permutation signum. #[doc(alias = "gsl_linalg_complex_LU_det")] -pub fn complex_LU_det(lu: &mut crate::MatrixComplexF64, signum: i32) -> crate::ComplexF64 { +pub fn complex_LU_det(lu: &mut crate::MatrixComplexF64, signum: i32) -> Complex { unsafe { sys::gsl_linalg_complex_LU_det(lu.unwrap_unique(), signum).wrap() } } @@ -406,7 +409,7 @@ pub fn LU_sgndet(lu: &mut crate::MatrixF64, signum: i32) -> i32 { /// This function computes the sign or phase factor of the determinant of a matrix A, \det(A)/|\det(A)|, from its LU decomposition, LU. #[doc(alias = "gsl_linalg_complex_LU_sgndet")] -pub fn complex_LU_sgndet(lu: &mut crate::MatrixComplexF64, signum: i32) -> crate::ComplexF64 { +pub fn complex_LU_sgndet(lu: &mut crate::MatrixComplexF64, signum: i32) -> Complex { unsafe { sys::gsl_linalg_complex_LU_sgndet(lu.unwrap_unique(), signum).wrap() } } @@ -1271,11 +1274,11 @@ pub fn householder_transform(v: &mut crate::VectorF64) -> f64 { /// This function prepares a Householder transformation P = I - \tau v v^T which can be used to zero all the elements of the input vector except /// the first. On output the transformation is stored in the vector v and the scalar \tau is returned. #[doc(alias = "gsl_linalg_complex_householder_transform")] -pub fn complex_householder_transform(v: &mut crate::VectorComplexF64) -> crate::ComplexF64 { +pub fn complex_householder_transform(v: &mut crate::VectorComplexF64) -> Complex { unsafe { - std::mem::transmute(sys::gsl_linalg_complex_householder_transform( + sys::gsl_linalg_complex_householder_transform( v.unwrap_unique(), - )) + ).wrap() } } @@ -1295,13 +1298,13 @@ pub fn householder_hm( /// the result P A is stored in A. #[doc(alias = "gsl_linalg_complex_householder_hm")] pub fn complex_householder_hm( - tau: &crate::ComplexF64, + tau: &Complex, v: &crate::VectorComplexF64, a: &mut crate::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_linalg_complex_householder_hm( - std::mem::transmute(*tau), + tau.unwrap(), v.unwrap_shared(), a.unwrap_unique(), ) @@ -1325,13 +1328,13 @@ pub fn householder_mh( /// the result A P is stored in A. #[doc(alias = "gsl_linalg_complex_householder_mh")] pub fn complex_householder_mh( - tau: &crate::ComplexF64, + tau: &Complex, v: &crate::VectorComplexF64, a: &mut crate::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_linalg_complex_householder_mh( - std::mem::transmute(*tau), + tau.unwrap(), v.unwrap_shared(), a.unwrap_unique(), ) @@ -1355,13 +1358,13 @@ pub fn householder_hv( /// w is stored in w. #[doc(alias = "gsl_linalg_complex_householder_hv")] pub fn complex_householder_hv( - tau: &crate::ComplexF64, + tau: &Complex, v: &crate::VectorComplexF64, w: &mut crate::VectorComplexF64, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_linalg_complex_householder_hv( - std::mem::transmute(*tau), + tau.unwrap(), v.unwrap_shared(), w.unwrap_unique(), ) diff --git a/src/polynomials.rs b/src/polynomials.rs index c2d44d45..d43e0837 100644 --- a/src/polynomials.rs +++ b/src/polynomials.rs @@ -26,9 +26,11 @@ R. L. Burden and J. D. Faires, Numerical Analysis, 9th edition, ISBN 0-538-73351 /// `P(x) = c[0] + c[1] x + c[2] x^2 + \dots + c[len-1] x^{len-1}` using Horner’s method for /// stability. pub mod evaluation { - use crate::types::ComplexF64; - use crate::Value; - use std::mem::transmute; + use crate::{ + types::complex::{ToC, FromC}, + Value, + }; + use num_complex::Complex; /// This function evaluates a polynomial with real coefficients for the real variable x. #[doc(alias = "gsl_poly_eval")] @@ -38,30 +40,31 @@ pub mod evaluation { /// This function evaluates a polynomial with real coefficients for the complex variable z. #[doc(alias = "gsl_poly_complex_eval")] - pub fn poly_complex_eval(c: &[f64], z: &ComplexF64) -> ComplexF64 { + pub fn poly_complex_eval(c: &[f64], z: &Complex) -> Complex { unsafe { - transmute(sys::gsl_poly_complex_eval( + sys::gsl_poly_complex_eval( c.as_ptr(), c.len() as i32, - transmute(*z), - )) + z.unwrap(), + ).wrap() } } /// This function evaluates a polynomial with complex coefficients for the complex variable z. #[doc(alias = "gsl_complex_poly_complex_eval")] - pub fn complex_poly_complex_eval(c: &[ComplexF64], z: &ComplexF64) -> ComplexF64 { + pub fn complex_poly_complex_eval(c: &[Complex], z: &Complex) -> Complex { + // FIXME: Making a copy should be unnecessary. let mut tmp = Vec::new(); for it in c.iter() { - unsafe { tmp.push(transmute(*it)) }; + tmp.push(it.unwrap()) } unsafe { - transmute(sys::gsl_complex_poly_complex_eval( + sys::gsl_complex_poly_complex_eval( tmp.as_ptr(), tmp.len() as i32, - transmute(*z), - )) + z.unwrap(), + ).wrap() } } @@ -171,8 +174,8 @@ pub mod divided_difference_representation { } pub mod quadratic_equations { - use crate::types::ComplexF64; use crate::Value; + use num_complex::Complex; use std::mem::transmute; /// This function finds the real roots of the quadratic equation, @@ -214,18 +217,21 @@ pub mod quadratic_equations { a: f64, b: f64, c: f64, - z0: &mut ComplexF64, - z1: &mut ComplexF64, + z0: &mut Complex, + z1: &mut Complex, ) -> Result<(), Value> { let ret = - unsafe { sys::gsl_poly_complex_solve_quadratic(a, b, c, transmute(z0), transmute(z1)) }; + unsafe { sys::gsl_poly_complex_solve_quadratic( + a, b, c, + transmute(z0), + transmute(z1)) }; result_handler!(ret, ()) } } pub mod cubic_equations { - use crate::types::ComplexF64; use crate::Value; + use num_complex::Complex; use std::mem::transmute; /// This function finds the real roots of the cubic equation, @@ -263,12 +269,16 @@ pub mod cubic_equations { a: f64, b: f64, c: f64, - z0: &mut ComplexF64, - z1: &mut ComplexF64, - z2: &mut ComplexF64, + z0: &mut Complex, + z1: &mut Complex, + z2: &mut Complex, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_poly_complex_solve_cubic(a, b, c, transmute(z0), transmute(z1), transmute(z2)) + sys::gsl_poly_complex_solve_cubic( + a, b, c, + transmute(z0), + transmute(z1), + transmute(z2)) }; result_handler!(ret, ()) } diff --git a/src/types/complex.rs b/src/types/complex.rs index 2d021b36..ecc2d208 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -2,1797 +2,1611 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // -// TODO : port to Rust type : http://doc.rust-lang.org/num/complex/struct.Complex.html +use num_complex::Complex; -use std::fmt::{self, Debug, Formatter}; +#[deprecated(since="8.0.0", note="use `Complex` instead")] +pub type ComplexF64 = Complex; +#[deprecated(since="8.0.0", note="use `Complex` instead")] +pub type ComplexF32 = Complex; -#[doc(hidden)] -#[allow(clippy::upper_case_acronyms)] -pub trait CFFI { - fn wrap(s: T) -> Self; +pub(crate) trait ToC { fn unwrap(self) -> T; } -#[doc(hidden)] -#[allow(clippy::upper_case_acronyms)] -pub trait FFFI { +pub(crate) trait FromC { fn wrap(self) -> T; - fn unwrap(t: T) -> Self; } -//#[deprecated(note = "Use `Complex64` from the `num_complex` create instead")] -#[repr(C)] -#[derive(Clone, Copy, PartialEq)] -pub struct ComplexF64 { - pub dat: [f64; 2], +impl ToC for Complex { + fn unwrap(self) -> sys::gsl_complex { + // Complex is memory layout compatible with [T; 2] + unsafe { std::mem::transmute(self) } + } } -impl ComplexF64 { - /// This function uses the rectangular Cartesian components (x,y) to return the complex number - /// z = x + i y. - #[doc(alias = "gsl_complex_rect")] - pub fn rect(x: f64, y: f64) -> ComplexF64 { - unsafe { sys::gsl_complex_rect(x, y).wrap() } +impl FromC> for sys::gsl_complex { + fn wrap(self) -> Complex { + unsafe { std::mem::transmute(self) } } +} - /// This function returns the complex number z = r \exp(i \theta) = r (\cos(\theta) + i - /// \sin(\theta)) from the polar representation (r,theta). - #[doc(alias = "gsl_complex_polar")] - pub fn polar(r: f64, theta: f64) -> ComplexF64 { - unsafe { sys::gsl_complex_polar(r, theta).wrap() } - } +/// Define the capabilities provided for Complex numbers by the GSL. +pub trait ComplexOps { + /// This function uses the rectangular Cartesian components (x,y) + /// to return the complex number z = x + i y. + #[doc(alias = "gsl_complex_rect")] + fn rect(x: T, y: T) -> Complex; - /// This function returns the argument of the complex number z, \arg(z), where -\pi < \arg(z) - /// <= \pi. - #[doc(alias = "gsl_complex_arg")] - pub fn arg(&self) -> f64 { - unsafe { sys::gsl_complex_arg(self.unwrap()) } - } + /// This function returns the complex number z = r \exp(iθ) = r + /// (\cos(θ) + i \sin(θ)) from the polar representation (`r`, `theta`). + #[doc(alias = "gsl_complex_polar")] + fn polar(r: T, theta: T) -> Complex; /// This function returns the magnitude of the complex number z, |z|. #[doc(alias = "gsl_complex_abs")] - pub fn abs(&self) -> f64 { - unsafe { sys::gsl_complex_abs(self.unwrap()) } - } + #[deprecated(since="8.0.0", note="please use `.norm()` instead")] + fn abs(&self) -> T; - /// This function returns the squared magnitude of the complex number z, |z|^2. + /// This function returns the squared magnitude of the complex + /// number z = `self`, |z|². #[doc(alias = "gsl_complex_abs2")] - pub fn abs2(&self) -> f64 { - unsafe { sys::gsl_complex_abs2(self.unwrap()) } - } + #[deprecated(since="8.0.0", note="please use `.norm_sqr()` instead")] + fn abs2(&self) -> T; - /// This function returns the natural logarithm of the magnitude of the complex number z, - /// \log|z|. + /// This function returns the natural logarithm of the magnitude + /// of the complex number z = `self`, log|z|. /// /// It allows an accurate evaluation of \log|z| when |z| is close to one. /// - /// The direct evaluation of log(gsl_complex_abs(z)) would lead to a loss of precision in this - /// case. + /// The direct evaluation of log(gsl_complex_abs(z)) would lead to + /// a loss of precision in this case. #[doc(alias = "gsl_complex_logabs")] - pub fn logabs(&self) -> f64 { - unsafe { sys::gsl_complex_logabs(self.unwrap()) } - } + fn logabs(&self) -> T; /// This function returns the sum of the complex numbers a and b, z=a+b. #[doc(alias = "gsl_complex_add")] - pub fn add(&self, other: &ComplexF64) -> ComplexF64 { + #[deprecated(since="8.0.0", note="please use `+` instead")] + fn add(&self, other: &Complex) -> Complex; + + /// This function returns the difference of the complex numbers a + /// and b, z=a-b. + #[doc(alias = "gsl_complex_sub")] + #[deprecated(since="8.0.0", note="please use `-` instead")] + fn sub(&self, other: &Complex) -> Complex; + + /// This function returns the product of the complex numbers a and b, z=ab. + #[doc(alias = "gsl_complex_mul")] + #[deprecated(since="8.0.0", note="please use `*` instead")] + fn mul(&self, other: &Complex) -> Complex; + + /// This function returns the quotient of the complex numbers a + /// and b, z=a/b. + #[doc(alias = "gsl_complex_div")] + #[deprecated(since="8.0.0", note="please use `/` of `fdiv` instead")] + fn div(&self, other: &Complex) -> Complex; + + /// This function returns the sum of the complex number a and the + /// real number x, z = a + x. + #[doc(alias = "gsl_complex_add_real")] + #[deprecated(since="8.0.0", note="please use `+` instead")] + fn add_real(&self, x: T) -> Complex; + + /// This function returns the difference of the complex number a + /// and the real number x, z=a-x. + #[doc(alias = "gsl_complex_sub_real")] + #[deprecated(since="8.0.0", note="please use `-` instead")] + fn sub_real(&self, x: T) -> Complex; + + /// This function returns the product of the complex number a and + /// the real number x, z=ax. + #[doc(alias = "gsl_complex_mul_real")] + #[deprecated(since="8.0.0", note="please use `*` instead")] + fn mul_real(&self, x: T) -> Complex; + + /// This function returns the quotient of the complex number a and + /// the real number x, z=a/x. + #[doc(alias = "gsl_complex_div_real")] + #[deprecated(since="8.0.0", note="please use `/` instead")] + fn div_real(&self, x: T) -> Complex; + + /// This function returns the sum of the complex number a and the + /// imaginary number iy, z=a+iy. + #[doc(alias = "gsl_complex_add_imag")] + #[deprecated(since="8.0.0", note="please use `self + x * Complex::I` instead")] + fn add_imag(&self, x: T) -> Complex; + + /// This function returns the difference of the complex number a + /// and the imaginary number iy, z=a-iy. + #[doc(alias = "gsl_complex_sub_imag")] + #[deprecated(since="8.0.0", note="please use `self - x * Complex::I` instead")] + fn sub_imag(&self, x: T) -> Complex; + + /// This function returns the product of the complex number a and + /// the imaginary number iy, z=a*(iy). + #[doc(alias = "gsl_complex_mul_imag")] + #[deprecated(since="8.0.0", note="please use `self * x * Complex::I` instead")] + fn mul_imag(&self, x: T) -> Complex; + + /// This function returns the quotient of the complex number a and + /// the imaginary number iy, z=a/(iy). + #[doc(alias = "gsl_complex_div_imag")] + #[deprecated(since="8.0.0", note="please use `self / (x * Complex::I)` instead")] + fn div_imag(&self, x: T) -> Complex; + + /// This function returns the complex conjugate of the complex + /// number z, z^* = x - i y. + #[doc(alias = "gsl_complex_conjugate")] + #[deprecated(since="8.0.0", note="please use `.conj()` instead")] + fn conjugate(&self) -> Complex; + + /// This function returns the inverse, or reciprocal, of the + /// complex number z, 1/z = (x - i y)/ (x^2 + y^2). + #[doc(alias = "gsl_complex_inverse")] + #[deprecated(since="8.0.0", note="please use `.inv()` instead")] + fn inverse(&self) -> Complex; + + /// This function returns the negative of the complex number z, -z + /// = (-x) + i(-y). + #[doc(alias = "gsl_complex_negative")] + #[deprecated(since="8.0.0", note="please use the unary `-` instead")] + fn negative(&self) -> Complex; + + /// This function returns the complex square root of the real + /// number x, where x may be negative. + #[doc(alias = "gsl_complex_sqrt_real")] + fn sqrt_real(x: T) -> Complex; + + /// The function returns the complex number z raised to the + /// complex power a z^a. This is computed as \exp(\log(z)*a) + /// using complex logarithms and complex exponentials. + #[doc(alias = "gsl_complex_pow")] + #[deprecated(since="8.0.0", note="please use the unary `-` instead")] + fn pow(&self, other: &Complex) -> Complex; + + /// This function returns the complex number z raised to the real + /// power x, z^x. + #[doc(alias = "gsl_complex_pow_real")] + #[deprecated(since="8.0.0", note="please use `.powf(x)` instead")] + fn pow_real(&self, x: T) -> Complex; + + /// This function returns the complex base-b logarithm of the + /// complex number z, \log_b(z). This quantity is computed as the + /// ratio \log(z)/\log(b). + #[doc(alias = "gsl_complex_log_b")] + #[deprecated(since="8.0.0", note="please use `.log(base)` instead")] + fn log_b(&self, base: &Complex) -> Complex; + + /// This function returns the complex secant of the complex number + /// z, \sec(z) = 1/\cos(z). + #[doc(alias = "gsl_complex_sec")] + fn sec(&self) -> Complex; + + /// This function returns the complex cosecant of the complex + /// number z, \csc(z) = 1/\sin(z). + #[doc(alias = "gsl_complex_csc")] + fn csc(&self) -> Complex; + + /// This function returns the complex cotangent of the complex + /// number z, \cot(z) = 1/\tan(z). + #[doc(alias = "gsl_complex_cot")] + fn cot(&self) -> Complex; + + /// This function returns the complex arcsine of the complex + /// number z, \arcsin(z). The branch cuts are on the real axis, + /// less than -1 and greater than 1. + #[doc(alias = "gsl_complex_arcsin")] + #[deprecated(since="8.0.0", note="please use `.asin()` instead")] + fn arcsin(&self) -> Complex; + + /// This function returns the complex arcsine of the real number + /// z, \arcsin(z). + /// + /// * For z between -1 and 1, the function returns a real value in + /// the range \[-π/2,π/2\]. + /// * For z less than -1 the result has a real part of -π/2 and + /// a positive imaginary part. + /// * For z greater than 1 the result has a real part of π/2 and + /// a negative imaginary part. + #[doc(alias = "gsl_complex_arcsin_real")] + fn arcsin_real(z: T) -> Complex; + + /// This function returns the complex arccosine of the complex + /// number z, \arccos(z). The branch cuts are on the real axis, + /// less than -1 and greater than 1. + #[doc(alias = "gsl_complex_arccos")] + #[deprecated(since="8.0.0", note="please use `.acos()` instead")] + fn arccos(&self) -> Complex; + + /// This function returns the complex arccosine of the real number + /// z, \arccos(z). + /// + /// * For z between -1 and 1, the function returns a real value in + /// the range \[0,π\]. + /// * For z less than -1 the result has a real part of \pi and a + /// negative imaginary part. + /// * For z greater than 1 the result is purely imaginary and positive. + #[doc(alias = "gsl_complex_arccos_real")] + fn arccos_real(z: T) -> Complex; + + /// This function returns the complex arctangent of the complex + /// number z, \arctan(z). The branch cuts are on the imaginary + /// axis, below -i and above i. + #[doc(alias = "gsl_complex_arctan")] + #[deprecated(since="8.0.0", note="please use `.atan()` instead")] + fn arctan(&self) -> Complex; + + /// This function returns the complex arcsecant of the complex + /// number z, \arcsec(z) = \arccos(1/z). + #[doc(alias = "gsl_complex_arcsec")] + fn arcsec(&self) -> Complex; + + /// This function returns the complex arcsecant of the real number + /// z, \arcsec(z) = \arccos(1/z). + #[doc(alias = "gsl_complex_arcsec_real")] + fn arcsec_real(z: T) -> Complex; + + /// This function returns the complex arccosecant of the complex + /// number z, \arccsc(z) = \arcsin(1/z). + #[doc(alias = "gsl_complex_arccsc")] + fn arccsc(&self) -> Complex; + + /// This function returns the complex arccosecant of the real + /// number z, \arccsc(z) = \arcsin(1/z). + #[doc(alias = "gsl_complex_arccsc_real")] + fn arccsc_real(z: T) -> Complex; + + /// This function returns the complex arccotangent of the complex + /// number z, \arccot(z) = \arctan(1/z). + #[doc(alias = "gsl_complex_arccot")] + fn arccot(&self) -> Complex; + + /// This function returns the complex hyperbolic secant of the + /// complex number z, \sech(z) = 1/\cosh(z). + #[doc(alias = "gsl_complex_sech")] + fn sech(&self) -> Complex; + + /// This function returns the complex hyperbolic cosecant of the + /// complex number z, \csch(z) = 1/\sinh(z). + #[doc(alias = "gsl_complex_csch")] + fn csch(&self) -> Complex; + + /// This function returns the complex hyperbolic cotangent of the + /// complex number z, \coth(z) = 1/\tanh(z). + #[doc(alias = "gsl_complex_coth")] + fn coth(&self) -> Complex; + + /// This function returns the complex hyperbolic arcsine of the + /// complex number z, \arcsinh(z). The branch cuts are on the + /// imaginary axis, below -i and above i. + #[doc(alias = "gsl_complex_arcsinh")] + #[deprecated(since="8.0.0", note="please use `.asinh()` instead")] + fn arcsinh(&self) -> Complex; + + /// This function returns the complex hyperbolic arccosine of the + /// complex number z, \arccosh(z). The branch cut is on the real + /// axis, less than 1. Note that in this case we use the negative + /// square root in formula 4.6.21 of Abramowitz & Stegun giving + /// \arccosh(z)=\log(z-\sqrt{z^2-1}). + #[doc(alias = "gsl_complex_arccosh")] + #[deprecated(since="8.0.0", note="please use `.acosh()` instead")] + fn arccosh(&self) -> Complex; + + /// This function returns the complex hyperbolic arccosine of the + /// real number z, \arccosh(z). + #[doc(alias = "gsl_complex_arccosh_real")] + fn arccosh_real(z: T) -> Complex; + + /// This function returns the complex hyperbolic arctangent of the + /// complex number z, \arctanh(z). + /// + /// The branch cuts are on the real axis, less than -1 and greater than 1. + #[doc(alias = "gsl_complex_arctanh")] + #[deprecated(since="8.0.0", note="please use `.atanh()` instead")] + fn arctanh(&self) -> Complex; + + /// This function returns the complex hyperbolic arctangent of the + /// real number z, \arctanh(z). + #[doc(alias = "gsl_complex_arctanh_real")] + fn arctanh_real(z: T) -> Complex; + + /// This function returns the complex hyperbolic arcsecant of the + /// complex number z, \arcsech(z) = \arccosh(1/z). + #[doc(alias = "gsl_complex_arcsech")] + fn arcsech(&self) -> Complex; + + /// This function returns the complex hyperbolic arccosecant of + /// the complex number z, \arccsch(z) = \arcsin(1/z). + #[doc(alias = "gsl_complex_arccsch")] + fn arccsch(&self) -> Complex; + + /// This function returns the complex hyperbolic arccotangent of + /// the complex number z, \arccoth(z) = \arctanh(1/z). + #[doc(alias = "gsl_complex_arccoth")] + fn arccoth(&self) -> Complex; + + #[deprecated(since="8.0.0", note="please use `.re` instead")] + fn real(&self) -> T; + + #[deprecated(since="8.0.0", note="please use `.im` instead")] + fn imaginary(&self) -> T; +} + +impl ComplexOps for Complex { + fn rect(x: f64, y: f64) -> Complex { + unsafe { sys::gsl_complex_rect(x, y).wrap() } + } + + fn polar(r: f64, theta: f64) -> Complex { + unsafe { sys::gsl_complex_polar(r, theta).wrap() } + } + + fn abs(&self) -> f64 { + unsafe { sys::gsl_complex_abs(self.unwrap()) } + } + + fn abs2(&self) -> f64 { + unsafe { sys::gsl_complex_abs2(self.unwrap()) } + } + + fn logabs(&self) -> f64 { + unsafe { sys::gsl_complex_logabs(self.unwrap()) } + } + + fn add(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_add(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the difference of the complex numbers a and b, z=a-b. - #[doc(alias = "gsl_complex_sub")] - pub fn sub(&self, other: &ComplexF64) -> ComplexF64 { + fn sub(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_sub(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the product of the complex numbers a and b, z=ab. - #[doc(alias = "gsl_complex_mul")] - pub fn mul(&self, other: &ComplexF64) -> ComplexF64 { + fn mul(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_mul(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the quotient of the complex numbers a and b, z=a/b. - #[doc(alias = "gsl_complex_div")] - pub fn div(&self, other: &ComplexF64) -> ComplexF64 { + fn div(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_div(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the sum of the complex number a and the real number x, z=a+x. - #[doc(alias = "gsl_complex_add_real")] - pub fn add_real(&self, x: f64) -> ComplexF64 { + fn add_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_add_real(self.unwrap(), x).wrap() } } - /// This function returns the difference of the complex number a and the real number x, z=a-x. - #[doc(alias = "gsl_complex_sub_real")] - pub fn sub_real(&self, x: f64) -> ComplexF64 { + fn sub_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_sub_real(self.unwrap(), x).wrap() } } - /// This function returns the product of the complex number a and the real number x, z=ax. - #[doc(alias = "gsl_complex_mul_real")] - pub fn mul_real(&self, x: f64) -> ComplexF64 { + fn mul_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_mul_real(self.unwrap(), x).wrap() } } - /// This function returns the quotient of the complex number a and the real number x, z=a/x. - #[doc(alias = "gsl_complex_div_real")] - pub fn div_real(&self, x: f64) -> ComplexF64 { + fn div_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_div_real(self.unwrap(), x).wrap() } } - /// This function returns the sum of the complex number a and the imaginary number iy, z=a+iy. - #[doc(alias = "gsl_complex_add_imag")] - pub fn add_imag(&self, x: f64) -> ComplexF64 { + fn add_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_add_imag(self.unwrap(), x).wrap() } } - /// This function returns the difference of the complex number a and the imaginary number iy, - /// z=a-iy. - #[doc(alias = "gsl_complex_sub_imag")] - pub fn sub_imag(&self, x: f64) -> ComplexF64 { + fn sub_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_sub_imag(self.unwrap(), x).wrap() } } - /// This function returns the product of the complex number a and the imaginary number iy, - /// z=a*(iy). - #[doc(alias = "gsl_complex_mul_imag")] - pub fn mul_imag(&self, x: f64) -> ComplexF64 { + fn mul_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_mul_imag(self.unwrap(), x).wrap() } } - /// This function returns the quotient of the complex number a and the imaginary number iy, - /// z=a/(iy). - #[doc(alias = "gsl_complex_div_imag")] - pub fn div_imag(&self, x: f64) -> ComplexF64 { + fn div_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_div_imag(self.unwrap(), x).wrap() } } - /// This function returns the complex conjugate of the complex number z, z^* = x - i y. - #[doc(alias = "gsl_complex_conjugate")] - pub fn conjugate(&self) -> ComplexF64 { + fn conjugate(&self) -> Complex { unsafe { sys::gsl_complex_conjugate(self.unwrap()).wrap() } } - /// This function returns the inverse, or reciprocal, of the complex number z, 1/z = (x - i y)/ - /// (x^2 + y^2). - #[doc(alias = "gsl_complex_inverse")] - pub fn inverse(&self) -> ComplexF64 { + fn inverse(&self) -> Complex { unsafe { sys::gsl_complex_inverse(self.unwrap()).wrap() } } - /// This function returns the negative of the complex number z, -z = (-x) + i(-y). - #[doc(alias = "gsl_complex_negative")] - pub fn negative(&self) -> ComplexF64 { + fn negative(&self) -> Complex { unsafe { sys::gsl_complex_negative(self.unwrap()).wrap() } } - /// This function returns the square root of the complex number z, \sqrt z. - /// - /// The branch cut is the negative real axis. The result always lies in the right half of the - /// omplex plane. - #[doc(alias = "gsl_complex_sqrt")] - pub fn sqrt(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_sqrt(self.unwrap()).wrap() } - } - - /// This function returns the complex square root of the real number x, where x may be negative. - #[doc(alias = "gsl_complex_sqrt_real")] - pub fn sqrt_real(x: f64) -> ComplexF64 { + fn sqrt_real(x: f64) -> Complex { unsafe { sys::gsl_complex_sqrt_real(x).wrap() } } - /// The function returns the complex number z raised to the complex power a, z^a. - /// This is computed as \exp(\log(z)*a) using complex logarithms and complex exponentials. - #[doc(alias = "gsl_complex_pow")] - pub fn pow(&self, other: &ComplexF64) -> ComplexF64 { + fn pow(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_pow(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the complex number z raised to the real power x, z^x. - #[doc(alias = "gsl_complex_pow_real")] - pub fn pow_real(&self, x: f64) -> ComplexF64 { + fn pow_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_pow_real(self.unwrap(), x).wrap() } } - /// This function returns the complex exponential of the complex number z, \exp(z). - #[doc(alias = "gsl_complex_exp")] - pub fn exp(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_exp(self.unwrap()).wrap() } - } - - /// This function returns the complex natural logarithm (base e) of the complex number z, - /// \log(z). - /// - /// The branch cut is the negative real axis. - #[doc(alias = "gsl_complex_log")] - pub fn log(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_log(self.unwrap()).wrap() } - } - - /// This function returns the complex base-10 logarithm of the complex number z, \log_10 (z). - #[doc(alias = "gsl_complex_log10")] - pub fn log10(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_log10(self.unwrap()).wrap() } - } - - /// This function returns the complex base-b logarithm of the complex number z, \log_b(z). - /// This quantity is computed as the ratio \log(z)/\log(b). - #[doc(alias = "gsl_complex_log_b")] - pub fn log_b(&self, other: &ComplexF64) -> ComplexF64 { + fn log_b(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_log_b(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the complex sine of the complex number z, \sin(z) = (\exp(iz) - - /// \exp(-iz))/(2i). - #[doc(alias = "gsl_complex_sin")] - pub fn sin(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_sin(self.unwrap()).wrap() } - } - - /// This function returns the complex cosine of the complex number z, \cos(z) = (\exp(iz) + - /// \exp(-iz))/2. - #[doc(alias = "gsl_complex_cos")] - pub fn cos(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_cos(self.unwrap()).wrap() } - } - - /// This function returns the complex tangent of the complex number z, \tan(z) = - /// \sin(z)/\cos(z). - #[doc(alias = "gsl_complex_tan")] - pub fn tan(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_tan(self.unwrap()).wrap() } - } - - /// This function returns the complex secant of the complex number z, \sec(z) = 1/\cos(z). - #[doc(alias = "gsl_complex_sec")] - pub fn sec(&self) -> ComplexF64 { + fn sec(&self) -> Complex { unsafe { sys::gsl_complex_sec(self.unwrap()).wrap() } } - /// This function returns the complex cosecant of the complex number z, \csc(z) = 1/\sin(z). - #[doc(alias = "gsl_complex_csc")] - pub fn csc(&self) -> ComplexF64 { + fn csc(&self) -> Complex { unsafe { sys::gsl_complex_csc(self.unwrap()).wrap() } } - /// This function returns the complex cotangent of the complex number z, \cot(z) = 1/\tan(z). - #[doc(alias = "gsl_complex_cot")] - pub fn cot(&self) -> ComplexF64 { + fn cot(&self) -> Complex { unsafe { sys::gsl_complex_cot(self.unwrap()).wrap() } } - /// This function returns the complex arcsine of the complex number z, \arcsin(z). - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arcsin")] - pub fn arcsin(&self) -> ComplexF64 { + fn arcsin(&self) -> Complex { unsafe { sys::gsl_complex_arcsin(self.unwrap()).wrap() } } - /// This function returns the complex arcsine of the real number z, \arcsin(z). - /// - /// * For z between -1 and 1, the function returns a real value in the range [-\pi/2,\pi/2]. - /// * For z less than -1 the result has a real part of -\pi/2 and a positive imaginary part. - /// * For z greater than 1 the result has a real part of \pi/2 and a negative imaginary part. - #[doc(alias = "gsl_complex_arcsin_real")] - pub fn arcsin_real(z: f64) -> ComplexF64 { + fn arcsin_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arcsin_real(z).wrap() } } - /// This function returns the complex arccosine of the complex number z, \arccos(z). - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arccos")] - pub fn arccos(&self) -> ComplexF64 { + fn arccos(&self) -> Complex { unsafe { sys::gsl_complex_arccos(self.unwrap()).wrap() } } - /// This function returns the complex arccosine of the real number z, \arccos(z). - /// - /// * For z between -1 and 1, the function returns a real value in the range [0,\pi]. - /// * For z less than -1 the result has a real part of \pi and a negative imaginary part. - /// * For z greater than 1 the result is purely imaginary and positive. - #[doc(alias = "gsl_complex_arccos_real")] - pub fn arccos_real(z: f64) -> ComplexF64 { + fn arccos_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccos_real(z).wrap() } } - /// This function returns the complex arctangent of the complex number z, \arctan(z). - /// The branch cuts are on the imaginary axis, below -i and above i. - #[doc(alias = "gsl_complex_arctan")] - pub fn arctan(&self) -> ComplexF64 { + fn arctan(&self) -> Complex { unsafe { sys::gsl_complex_arctan(self.unwrap()).wrap() } } - /// This function returns the complex arcsecant of the complex number z, \arcsec(z) = - /// \arccos(1/z). - #[doc(alias = "gsl_complex_arcsec")] - pub fn arcsec(&self) -> ComplexF64 { + fn arcsec(&self) -> Complex { unsafe { sys::gsl_complex_arcsec(self.unwrap()).wrap() } } - /// This function returns the complex arcsecant of the real number z, \arcsec(z) = \arccos(1/z). - #[doc(alias = "gsl_complex_arcsec_real")] - pub fn arcsec_real(z: f64) -> ComplexF64 { + fn arcsec_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arcsec_real(z).wrap() } } - /// This function returns the complex arccosecant of the complex number z, \arccsc(z) = - /// \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsc")] - pub fn arccsc(&self) -> ComplexF64 { + fn arccsc(&self) -> Complex { unsafe { sys::gsl_complex_arccsc(self.unwrap()).wrap() } } - /// This function returns the complex arccosecant of the real number z, \arccsc(z) = - /// \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsc_real")] - pub fn arccsc_real(z: f64) -> ComplexF64 { + fn arccsc_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccsc_real(z).wrap() } } - /// This function returns the complex arccotangent of the complex number z, \arccot(z) = - /// \arctan(1/z). - #[doc(alias = "gsl_complex_arccot")] - pub fn arccot(&self) -> ComplexF64 { + fn arccot(&self) -> Complex { unsafe { sys::gsl_complex_arccot(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic sine of the complex number z, \sinh(z) = - /// (\exp(z) - \exp(-z))/2. - #[doc(alias = "gsl_complex_sinh")] - pub fn sinh(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_sinh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic cosine of the complex number z, \cosh(z) = - /// (\exp(z) + \exp(-z))/2. - #[doc(alias = "gsl_complex_cosh")] - pub fn cosh(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_cosh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic tangent of the complex number z, \tanh(z) = - /// \sinh(z)/\cosh(z). - #[doc(alias = "gsl_complex_tanh")] - pub fn tanh(&self) -> ComplexF64 { - unsafe { sys::gsl_complex_tanh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic secant of the complex number z, \sech(z) = - /// 1/\cosh(z). - #[doc(alias = "gsl_complex_sech")] - pub fn sech(&self) -> ComplexF64 { + fn sech(&self) -> Complex { unsafe { sys::gsl_complex_sech(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic cosecant of the complex number z, \csch(z) = - /// 1/\sinh(z). - #[doc(alias = "gsl_complex_csch")] - pub fn csch(&self) -> ComplexF64 { + fn csch(&self) -> Complex { unsafe { sys::gsl_complex_csch(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic cotangent of the complex number z, \coth(z) = - /// 1/\tanh(z). - #[doc(alias = "gsl_complex_coth")] - pub fn coth(&self) -> ComplexF64 { + fn coth(&self) -> Complex { unsafe { sys::gsl_complex_coth(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arcsine of the complex number z, \arcsinh(z). - /// The branch cuts are on the imaginary axis, below -i and above i. - #[doc(alias = "gsl_complex_arcsinh")] - pub fn arcsinh(&self) -> ComplexF64 { + fn arcsinh(&self) -> Complex { unsafe { sys::gsl_complex_arcsinh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosine of the complex number z, \arccosh(z). - /// The branch cut is on the real axis, less than 1. - /// Note that in this case we use the negative square root in formula 4.6.21 of Abramowitz & - /// Stegun giving \arccosh(z)=\log(z-\sqrt{z^2-1}). - #[doc(alias = "gsl_complex_arccosh")] - pub fn arccosh(&self) -> ComplexF64 { + fn arccosh(&self) -> Complex { unsafe { sys::gsl_complex_arccosh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosine of the real number z, \arccosh(z). - #[doc(alias = "gsl_complex_arccosh_real")] - pub fn arccosh_real(z: f64) -> ComplexF64 { + fn arccosh_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccosh_real(z).wrap() } } - /// This function returns the complex hyperbolic arctangent of the complex number z, - /// \arctanh(z). - /// - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arctanh")] - pub fn arctanh(&self) -> ComplexF64 { + fn arctanh(&self) -> Complex { unsafe { sys::gsl_complex_arctanh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arctangent of the real number z, \arctanh(z). - #[doc(alias = "gsl_complex_arctanh_real")] - pub fn arctanh_real(z: f64) -> ComplexF64 { + fn arctanh_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arctanh_real(z).wrap() } } - /// This function returns the complex hyperbolic arcsecant of the complex number z, \arcsech(z) - /// = \arccosh(1/z). - #[doc(alias = "gsl_complex_arcsech")] - pub fn arcsech(&self) -> ComplexF64 { + fn arcsech(&self) -> Complex { unsafe { sys::gsl_complex_arcsech(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosecant of the complex number z, - /// \arccsch(z) = \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsch")] - pub fn arccsch(&self) -> ComplexF64 { + fn arccsch(&self) -> Complex { unsafe { sys::gsl_complex_arccsch(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccotangent of the complex number z, - /// \arccoth(z) = \arctanh(1/z). - #[doc(alias = "gsl_complex_arccoth")] - pub fn arccoth(&self) -> ComplexF64 { + fn arccoth(&self) -> Complex { unsafe { sys::gsl_complex_arccoth(self.unwrap()).wrap() } } - pub fn real(&self) -> f64 { - self.dat[0] + fn real(&self) -> f64 { + self.re } - pub fn imaginary(&self) -> f64 { - self.dat[1] + fn imaginary(&self) -> f64 { + self.im } } -impl Debug for ComplexF64 { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "[{}, {}]", self.dat[0], self.dat[1]) - } -} -impl Default for ComplexF64 { - fn default() -> ComplexF64 { - ComplexF64 { dat: [0f64, 0f64] } +// The GLS Complex module does not support `f32` operations. Thus we +// convert back and forth to `f64`. +impl ToC for Complex { + fn unwrap(self) -> sys::gsl_complex { + sys::gsl_complex { dat: [self.re as f64, self.im as f64]} } } -impl CFFI for ComplexF64 { - fn wrap(t: sys::gsl_complex) -> ComplexF64 { - unsafe { std::mem::transmute(t) } - } - - fn unwrap(self) -> sys::gsl_complex { +// For use by other modules (e.g. `blas`). +impl FromC> for sys::gsl_complex_float { + fn wrap(self) -> Complex { + // Complex is memory layout compatible with [T; 2] unsafe { std::mem::transmute(self) } } } -impl CFFI for ComplexF64 { - fn wrap(t: sys::gsl_complex_float) -> ComplexF64 { - ComplexF64 { - dat: [t.dat[0] as f64, t.dat[1] as f64], - } - } - +impl ToC for Complex { fn unwrap(self) -> sys::gsl_complex_float { - sys::gsl_complex_float { - dat: [self.dat[0] as f32, self.dat[1] as f32], - } - } -} - -impl FFFI for sys::gsl_complex { - fn wrap(self) -> ComplexF32 { - ComplexF32 { - dat: [self.dat[0] as f32, self.dat[1] as f32], - } - } - - fn unwrap(t: ComplexF32) -> sys::gsl_complex { - sys::gsl_complex { - dat: [t.dat[0] as f64, t.dat[1] as f64], - } - } -} - -impl FFFI for sys::gsl_complex { - fn wrap(self) -> ComplexF64 { unsafe { std::mem::transmute(self) } } - - fn unwrap(t: ComplexF64) -> sys::gsl_complex { - unsafe { std::mem::transmute(t) } - } } -#[repr(C)] -#[derive(Clone, Copy, PartialEq)] -pub struct ComplexF32 { - pub dat: [f32; 2], +impl FromC> for sys::gsl_complex { + fn wrap(self) -> Complex { + let [re, im] = self.dat; + Complex { re: re as f32, im: im as f32 } + } } -impl ComplexF32 { - /// This function uses the rectangular Cartesian components (x,y) to return the complex number - /// z = x + i y. - #[doc(alias = "gsl_complex_rect")] - pub fn rect(x: f32, y: f32) -> ComplexF32 { +impl ComplexOps for Complex { + fn rect(x: f32, y: f32) -> Complex { unsafe { sys::gsl_complex_rect(x as f64, y as f64).wrap() } } - /// This function returns the complex number z = r \exp(i \theta) = r (\cos(\theta) + i - /// \sin(\theta)) from the polar representation (r,theta). - #[doc(alias = "gsl_complex_polar")] - pub fn polar(r: f32, theta: f32) -> ComplexF32 { + fn polar(r: f32, theta: f32) -> Complex { unsafe { sys::gsl_complex_polar(r as f64, theta as f64).wrap() } } - /// This function returns the argument of the complex number z, \arg(z), where -\pi < \arg(z) - /// <= \pi. - #[doc(alias = "gsl_complex_arg")] - pub fn arg(&self) -> f32 { - unsafe { sys::gsl_complex_arg(self.unwrap()) as f32 } - } - - /// This function returns the magnitude of the complex number z, |z|. - #[doc(alias = "gsl_complex_abs")] - pub fn abs(&self) -> f32 { + fn abs(&self) -> f32 { unsafe { sys::gsl_complex_abs(self.unwrap()) as f32 } } - /// This function returns the squared magnitude of the complex number z, |z|^2. - #[doc(alias = "gsl_complex_abs2")] - pub fn abs2(&self) -> f32 { + fn abs2(&self) -> f32 { unsafe { sys::gsl_complex_abs2(self.unwrap()) as f32 } } - /// This function returns the natural logarithm of the magnitude of the complex number z, - /// \log|z|. - /// - /// It allows an accurate evaluation of \log|z| when |z| is close to one. - /// The direct evaluation of log(gsl_complex_abs(z)) would lead to a loss of precision in - /// this case. - #[doc(alias = "gsl_complex_logabs")] - pub fn logabs(&self) -> f32 { + fn logabs(&self) -> f32 { unsafe { sys::gsl_complex_logabs(self.unwrap()) as f32 } } - /// This function returns the sum of the complex numbers a and b, z=a+b. - #[doc(alias = "gsl_complex_add")] - pub fn add(&self, other: &ComplexF32) -> ComplexF32 { + fn add(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_add(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the difference of the complex numbers a and b, z=a-b. - #[doc(alias = "gsl_complex_sub")] - pub fn sub(&self, other: &ComplexF32) -> ComplexF32 { + fn sub(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_sub(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the product of the complex numbers a and b, z=ab. - #[doc(alias = "gsl_complex_mul")] - pub fn mul(&self, other: &ComplexF32) -> ComplexF32 { + fn mul(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_mul(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the quotient of the complex numbers a and b, z=a/b. - #[doc(alias = "gsl_complex_div")] - pub fn div(&self, other: &ComplexF32) -> ComplexF32 { + fn div(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_div(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the sum of the complex number a and the real number x, z=a+x. - #[doc(alias = "gsl_complex_add_real")] - pub fn add_real(&self, x: f32) -> ComplexF32 { + fn add_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_add_real(self.unwrap(), x as f64).wrap() } } - /// This function returns the difference of the complex number a and the real number x, z=a-x. - #[doc(alias = "gsl_complex_sub_real")] - pub fn sub_real(&self, x: f32) -> ComplexF32 { + fn sub_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_sub_real(self.unwrap(), x as f64).wrap() } } - /// This function returns the product of the complex number a and the real number x, z=ax. - #[doc(alias = "gsl_complex_mul_real")] - pub fn mul_real(&self, x: f32) -> ComplexF32 { + fn mul_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_mul_real(self.unwrap(), x as f64).wrap() } } - /// This function returns the quotient of the complex number a and the real number x, z=a/x. - #[doc(alias = "gsl_complex_div_real")] - pub fn div_real(&self, x: f32) -> ComplexF32 { + fn div_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_div_real(self.unwrap(), x as f64).wrap() } } - /// This function returns the sum of the complex number a and the imaginary number iy, z=a+iy. - #[doc(alias = "gsl_complex_add_imag")] - pub fn add_imag(&self, x: f32) -> ComplexF32 { + fn add_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_add_imag(self.unwrap(), x as f64).wrap() } } - /// This function returns the difference of the complex number a and the imaginary number iy, z=a-iy. - #[doc(alias = "gsl_complex_sub_imag")] - pub fn sub_imag(&self, x: f32) -> ComplexF32 { + fn sub_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_sub_imag(self.unwrap(), x as f64).wrap() } } - /// This function returns the product of the complex number a and the imaginary number iy, z=a*(iy). - #[doc(alias = "gsl_complex_mul_imag")] - pub fn mul_imag(&self, x: f32) -> ComplexF32 { + fn mul_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_mul_imag(self.unwrap(), x as f64).wrap() } } - /// This function returns the quotient of the complex number a and the imaginary number iy, z=a/(iy). - #[doc(alias = "gsl_complex_div_imag")] - pub fn div_imag(&self, x: f32) -> ComplexF32 { + fn div_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_div_imag(self.unwrap(), x as f64).wrap() } } - /// This function returns the complex conjugate of the complex number z, z^* = x - i y. - #[doc(alias = "gsl_complex_conjugate")] - pub fn conjugate(&self) -> ComplexF32 { + fn conjugate(&self) -> Complex { unsafe { sys::gsl_complex_conjugate(self.unwrap()).wrap() } } - /// This function returns the inverse, or reciprocal, of the complex number z, 1/z = (x - i y)/ - /// (x^2 + y^2). - #[doc(alias = "gsl_complex_inverse")] - pub fn inverse(&self) -> ComplexF32 { + fn inverse(&self) -> Complex { unsafe { sys::gsl_complex_inverse(self.unwrap()).wrap() } } - /// This function returns the negative of the complex number z, -z = (-x) + i(-y). - #[doc(alias = "gsl_complex_negative")] - pub fn negative(&self) -> ComplexF32 { + fn negative(&self) -> Complex { unsafe { sys::gsl_complex_negative(self.unwrap()).wrap() } } - /// This function returns the square root of the complex number z, \sqrt z. - /// - /// The branch cut is the negative real axis. The result always lies in the right half of the - /// complex plane. - #[doc(alias = "gsl_complex_sqrt")] - pub fn sqrt(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_sqrt(self.unwrap()).wrap() } - } - - /// This function returns the complex square root of the real number x, where x may be negative. - #[doc(alias = "gsl_complex_sqrt_real")] - pub fn sqrt_real(x: f32) -> ComplexF32 { + fn sqrt_real(x: f32) -> Complex { unsafe { sys::gsl_complex_sqrt_real(x as f64).wrap() } } - /// The function returns the complex number z raised to the complex power a, z^a. - /// - /// This is computed as \exp(\log(z)*a) using complex logarithms and complex exponentials. - #[doc(alias = "gsl_complex_pow")] - pub fn pow(&self, other: &ComplexF32) -> ComplexF32 { + fn pow(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_pow(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the complex number z raised to the real power x, z^x. - #[doc(alias = "gsl_complex_pow_real")] - pub fn pow_real(&self, x: f32) -> ComplexF32 { + fn pow_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_pow_real(self.unwrap(), x as f64).wrap() } } - /// This function returns the complex exponential of the complex number z, \exp(z). - #[doc(alias = "gsl_complex_exp")] - pub fn exp(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_exp(self.unwrap()).wrap() } - } - - /// This function returns the complex natural logarithm (base e) of the complex number z, \log(z). - /// The branch cut is the negative real axis. - #[doc(alias = "gsl_complex_log")] - pub fn log(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_log(self.unwrap()).wrap() } - } - - /// This function returns the complex base-10 logarithm of the complex number z, \log_10 (z). - #[doc(alias = "gsl_complex_log10")] - pub fn log10(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_log10(self.unwrap()).wrap() } - } - - /// This function returns the complex base-b logarithm of the complex number z, \log_b(z). - /// This quantity is computed as the ratio \log(z)/\log(b). - #[doc(alias = "gsl_complex_log_b")] - pub fn log_b(&self, other: &ComplexF32) -> ComplexF32 { + fn log_b(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_log_b(self.unwrap(), other.unwrap()).wrap() } } - /// This function returns the complex sine of the complex number z, \sin(z) = (\exp(iz) - - /// \exp(-iz))/(2i). - #[doc(alias = "gsl_complex_sin")] - pub fn sin(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_sin(self.unwrap()).wrap() } - } - - /// This function returns the complex cosine of the complex number z, \cos(z) = (\exp(iz) + - /// \exp(-iz))/2. - #[doc(alias = "gsl_complex_cos")] - pub fn cos(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_cos(self.unwrap()).wrap() } - } - - /// This function returns the complex tangent of the complex number z, \tan(z) = - /// \sin(z)/\cos(z). - #[doc(alias = "gsl_complex_tan")] - pub fn tan(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_tan(self.unwrap()).wrap() } - } - - /// This function returns the complex secant of the complex number z, \sec(z) = 1/\cos(z). - #[doc(alias = "gsl_complex_sec")] - pub fn sec(&self) -> ComplexF32 { + fn sec(&self) -> Complex { unsafe { sys::gsl_complex_sec(self.unwrap()).wrap() } } - /// This function returns the complex cosecant of the complex number z, \csc(z) = 1/\sin(z). - #[doc(alias = "gsl_complex_csc")] - pub fn csc(&self) -> ComplexF32 { + fn csc(&self) -> Complex { unsafe { sys::gsl_complex_csc(self.unwrap()).wrap() } } - /// This function returns the complex cotangent of the complex number z, \cot(z) = 1/\tan(z). - #[doc(alias = "gsl_complex_cot")] - pub fn cot(&self) -> ComplexF32 { + fn cot(&self) -> Complex { unsafe { sys::gsl_complex_cot(self.unwrap()).wrap() } } - /// This function returns the complex arcsine of the complex number z, \arcsin(z). - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arcsin")] - pub fn arcsin(&self) -> ComplexF32 { + fn arcsin(&self) -> Complex { unsafe { sys::gsl_complex_arcsin(self.unwrap()).wrap() } } - /// This function returns the complex arcsine of the real number z, \arcsin(z). - /// - /// * For z between -1 and 1, the function returns a real value in the range [-\pi/2,\pi/2]. - /// * For z less than -1 the result has a real part of -\pi/2 and a positive imaginary part. - /// * For z greater than 1 the result has a real part of \pi/2 and a negative imaginary part. - #[doc(alias = "gsl_complex_arcsin_real")] - pub fn arcsin_real(z: f32) -> ComplexF32 { + fn arcsin_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arcsin_real(z as f64).wrap() } } - /// This function returns the complex arccosine of the complex number z, \arccos(z). - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arccos")] - pub fn arccos(&self) -> ComplexF32 { + fn arccos(&self) -> Complex { unsafe { sys::gsl_complex_arccos(self.unwrap()).wrap() } } - /// This function returns the complex arccosine of the real number z, \arccos(z). - /// - /// * For z between -1 and 1, the function returns a real value in the range [0,\pi]. - /// * For z less than -1 the result has a real part of \pi and a negative imaginary part. - /// * For z greater than 1 the result is purely imaginary and positive. - #[doc(alias = "gsl_complex_arccos_real")] - pub fn arccos_real(z: f32) -> ComplexF32 { + fn arccos_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccos_real(z as f64).wrap() } } - /// This function returns the complex arctangent of the complex number z, \arctan(z). - /// The branch cuts are on the imaginary axis, below -i and above i. - #[doc(alias = "gsl_complex_arctan")] - pub fn arctan(&self) -> ComplexF32 { + fn arctan(&self) -> Complex { unsafe { sys::gsl_complex_arctan(self.unwrap()).wrap() } } - /// This function returns the complex arcsecant of the complex number z, \arcsec(z) = - /// \arccos(1/z). - #[doc(alias = "gsl_complex_arcsec")] - pub fn arcsec(&self) -> ComplexF32 { + fn arcsec(&self) -> Complex { unsafe { sys::gsl_complex_arcsec(self.unwrap()).wrap() } } - /// This function returns the complex arcsecant of the real number z, \arcsec(z) = \arccos(1/z). - #[doc(alias = "gsl_complex_arcsec_real")] - pub fn arcsec_real(z: f32) -> ComplexF32 { + fn arcsec_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arcsec_real(z as f64).wrap() } } - /// This function returns the complex arccosecant of the complex number z, \arccsc(z) = - /// \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsc")] - pub fn arccsc(&self) -> ComplexF32 { + fn arccsc(&self) -> Complex { unsafe { sys::gsl_complex_arccsc(self.unwrap()).wrap() } } - /// This function returns the complex arccosecant of the real number z, \arccsc(z) = - /// \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsc_real")] - pub fn arccsc_real(z: f32) -> ComplexF32 { + fn arccsc_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccsc_real(z as f64).wrap() } } - /// This function returns the complex arccotangent of the complex number z, \arccot(z) = - /// \arctan(1/z). - #[doc(alias = "gsl_complex_arccot")] - pub fn arccot(&self) -> ComplexF32 { + fn arccot(&self) -> Complex { unsafe { sys::gsl_complex_arccot(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic sine of the complex number z, \sinh(z) = - /// (\exp(z) - \exp(-z))/2. - #[doc(alias = "gsl_complex_sinh")] - pub fn sinh(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_sinh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic cosine of the complex number z, \cosh(z) = - /// (\exp(z) + \exp(-z))/2. - #[doc(alias = "gsl_complex_cosh")] - pub fn cosh(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_cosh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic tangent of the complex number z, \tanh(z) = - /// \sinh(z)/\cosh(z). - #[doc(alias = "gsl_complex_tanh")] - pub fn tanh(&self) -> ComplexF32 { - unsafe { sys::gsl_complex_tanh(self.unwrap()).wrap() } - } - - /// This function returns the complex hyperbolic secant of the complex number z, \sech(z) = - /// 1/\cosh(z). - #[doc(alias = "gsl_complex_sech")] - pub fn sech(&self) -> ComplexF32 { + fn sech(&self) -> Complex { unsafe { sys::gsl_complex_sech(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic cosecant of the complex number z, \csch(z) = - /// 1/\sinh(z). - #[doc(alias = "gsl_complex_csch")] - pub fn csch(&self) -> ComplexF32 { + fn csch(&self) -> Complex { unsafe { sys::gsl_complex_csch(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic cotangent of the complex number z, \coth(z) = - /// 1/\tanh(z). - #[doc(alias = "gsl_complex_coth")] - pub fn coth(&self) -> ComplexF32 { + fn coth(&self) -> Complex { unsafe { sys::gsl_complex_coth(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arcsine of the complex number z, \arcsinh(z). - /// The branch cuts are on the imaginary axis, below -i and above i. - #[doc(alias = "gsl_complex_arcsinh")] - pub fn arcsinh(&self) -> ComplexF32 { + fn arcsinh(&self) -> Complex { unsafe { sys::gsl_complex_arcsinh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosine of the complex number z, \arccosh(z). - /// - /// The branch cut is on the real axis, less than 1. - /// - /// Note that in this case we use the negative square root in formula 4.6.21 of Abramowitz & - /// Stegun giving \arccosh(z)=\log(z-\sqrt{z^2-1}). - #[doc(alias = "gsl_complex_arccosh")] - pub fn arccosh(&self) -> ComplexF32 { + fn arccosh(&self) -> Complex { unsafe { sys::gsl_complex_arccosh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosine of the real number z, \arccosh(z). - #[doc(alias = "gsl_complex_arccosh_real")] - pub fn arccosh_real(z: f32) -> ComplexF32 { + fn arccosh_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccosh_real(z as f64).wrap() } } - /// This function returns the complex hyperbolic arctangent of the complex number z, - /// arctanh(z). - /// - /// The branch cuts are on the real axis, less than -1 and greater than 1. - #[doc(alias = "gsl_complex_arctanh")] - pub fn arctanh(&self) -> ComplexF32 { + fn arctanh(&self) -> Complex { unsafe { sys::gsl_complex_arctanh(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arctangent of the real number z, \arctanh(z). - #[doc(alias = "gsl_complex_arctanh_real")] - pub fn arctanh_real(z: f32) -> ComplexF32 { + fn arctanh_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arctanh_real(z as f64).wrap() } } - /// This function returns the complex hyperbolic arcsecant of the complex number z, \arcsech(z) - /// = \arccosh(1/z). - #[doc(alias = "gsl_complex_arcsech")] - pub fn arcsech(&self) -> ComplexF32 { + fn arcsech(&self) -> Complex { unsafe { sys::gsl_complex_arcsech(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccosecant of the complex number z, - /// \arccsch(z) = \arcsin(1/z). - #[doc(alias = "gsl_complex_arccsch")] - pub fn arccsch(&self) -> ComplexF32 { + fn arccsch(&self) -> Complex { unsafe { sys::gsl_complex_arccsch(self.unwrap()).wrap() } } - /// This function returns the complex hyperbolic arccotangent of the complex number z, - /// \arccoth(z) = \arctanh(1/z). - #[doc(alias = "gsl_complex_arccoth")] - pub fn arccoth(&self) -> ComplexF32 { + fn arccoth(&self) -> Complex { unsafe { sys::gsl_complex_arccoth(self.unwrap()).wrap() } } - pub fn real(&self) -> f32 { - self.dat[0] - } - - pub fn imaginary(&self) -> f32 { - self.dat[1] - } -} - -impl Debug for ComplexF32 { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "[{}, {}]", self.dat[0], self.dat[1]) - } -} - -impl Default for ComplexF32 { - fn default() -> ComplexF32 { - ComplexF32 { dat: [0f32, 0f32] } - } -} - -impl CFFI for ComplexF32 { - fn wrap(s: sys::gsl_complex) -> ComplexF32 { - ComplexF32 { - dat: [s.dat[0] as f32, s.dat[1] as f32], - } + fn real(&self) -> f32 { + self.re } - fn unwrap(self) -> sys::gsl_complex { - sys::gsl_complex { - dat: [self.dat[0] as f64, self.dat[1] as f64], - } + fn imaginary(&self) -> f32 { + self.im } } -impl CFFI for ComplexF32 { - fn wrap(s: sys::gsl_complex_float) -> ComplexF32 { - unsafe { std::mem::transmute(s) } - } - fn unwrap(self) -> sys::gsl_complex_float { - unsafe { std::mem::transmute(self) } +#[cfg(test)] +mod tests { + // All these tests have been tested against the following C code: + // + // ```ignore + // #include + // #include + // #include + // #include + // + // void print_complex(gsl_complex *c, const char *text) { + // printf("%s: %f %f\n", text, c->dat[0], c->dat[1]); + // } + // + // int main (void) + // { + // gsl_complex c = gsl_complex_rect(10., 10.); + // gsl_complex c2 = gsl_complex_rect(1., -1.); + // print_complex(&c, "rect"); + // print_complex(&c2, "rect"); + // gsl_complex c3 = gsl_complex_polar(5., 7.); + // print_complex(&c3, "polar"); + // + // printf("-> %f\n", gsl_complex_arg(c3)); + // printf("-> %f\n", gsl_complex_abs(c3)); + // printf("-> %f\n", gsl_complex_abs2(c3)); + // printf("-> %f\n", gsl_complex_logabs(c3)); + // + // { + // gsl_complex c4 = gsl_complex_add(c3, c2); + // print_complex(&c4, "\nadd"); + // } + // { + // gsl_complex c4 = gsl_complex_sub(c3, c2); + // print_complex(&c4, "sub"); + // } + // { + // gsl_complex c4 = gsl_complex_mul(c3, c2); + // print_complex(&c4, "mul"); + // } + // { + // gsl_complex c4 = gsl_complex_div(c3, c2); + // print_complex(&c4, "div"); + // } + // { + // gsl_complex c4 = gsl_complex_add_real(c3, 5.); + // print_complex(&c4, "add_real"); + // } + // { + // gsl_complex c4 = gsl_complex_sub_real(c3, 5.); + // print_complex(&c4, "sub_real"); + // } + // { + // gsl_complex c4 = gsl_complex_mul_real(c3, 5.); + // print_complex(&c4, "mul_real"); + // } + // { + // gsl_complex c4 = gsl_complex_div_real(c3, 5.); + // print_complex(&c4, "div_real"); + // } + // + // + // { + // gsl_complex c4 = gsl_complex_add_imag(c3, 5.); + // print_complex(&c4, "\nadd_imag"); + // } + // { + // gsl_complex c4 = gsl_complex_sub_imag(c3, 5.); + // print_complex(&c4, "sub_imag"); + // } + // { + // gsl_complex c4 = gsl_complex_mul_imag(c3, 5.); + // print_complex(&c4, "mul_imag"); + // } + // { + // gsl_complex c4 = gsl_complex_div_imag(c3, 5.); + // print_complex(&c4, "div_imag"); + // } + // + // + // { + // gsl_complex c4 = gsl_complex_conjugate(c3); + // print_complex(&c4, "\nconjugate"); + // } + // { + // gsl_complex c4 = gsl_complex_inverse(c3); + // print_complex(&c4, "inverse"); + // } + // { + // gsl_complex c4 = gsl_complex_negative(c3); + // print_complex(&c4, "negative"); + // } + // { + // gsl_complex c4 = gsl_complex_sqrt(c3); + // print_complex(&c4, "sqrt"); + // } + // { + // gsl_complex c4 = gsl_complex_sqrt_real(5.); + // print_complex(&c4, "sqrt_real"); + // } + // + // + // { + // gsl_complex c4 = gsl_complex_pow(c3, c2); + // print_complex(&c4, "\npow"); + // } + // { + // gsl_complex c4 = gsl_complex_pow_real(c3, 5.); + // print_complex(&c4, "pow_real"); + // } + // { + // gsl_complex c4 = gsl_complex_exp(c3); + // print_complex(&c4, "exp"); + // } + // { + // gsl_complex c4 = gsl_complex_log(c3); + // print_complex(&c4, "log"); + // } + // { + // gsl_complex c4 = gsl_complex_log10(c3); + // print_complex(&c4, "log10"); + // } + // { + // gsl_complex c4 = gsl_complex_log_b(c3, c2); + // print_complex(&c4, "log_b"); + // } + // { + // gsl_complex c4 = gsl_complex_sin(c3); + // print_complex(&c4, "sin"); + // } + // { + // gsl_complex c4 = gsl_complex_cos(c3); + // print_complex(&c4, "cos"); + // } + // { + // gsl_complex c4 = gsl_complex_tan(c3); + // print_complex(&c4, "tan"); + // } + // + // + // { + // gsl_complex c4 = gsl_complex_sec(c3); + // print_complex(&c4, "\nsec"); + // } + // { + // gsl_complex c4 = gsl_complex_csc(c3); + // print_complex(&c4, "csc"); + // } + // { + // gsl_complex c4 = gsl_complex_cot(c3); + // print_complex(&c4, "cot"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsin(c3); + // print_complex(&c4, "arcsin"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsin_real(5.); + // print_complex(&c4, "arcsin_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arccos(c3); + // print_complex(&c4, "arccos"); + // } + // { + // gsl_complex c4 = gsl_complex_arccos_real(5.); + // print_complex(&c4, "arccos_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arctan(c3); + // print_complex(&c4, "arctan"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsec(c3); + // print_complex(&c4, "arcsec"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsec_real(5.); + // print_complex(&c4, "arcsec_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arccsc(c3); + // print_complex(&c4, "arccsc"); + // } + // { + // gsl_complex c4 = gsl_complex_arccsc_real(5.); + // print_complex(&c4, "arccsc_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arccot(c3); + // print_complex(&c4, "arccot"); + // } + // { + // gsl_complex c4 = gsl_complex_sinh(c3); + // print_complex(&c4, "sinh"); + // } + // { + // gsl_complex c4 = gsl_complex_cosh(c3); + // print_complex(&c4, "cosh"); + // } + // { + // gsl_complex c4 = gsl_complex_tanh(c3); + // print_complex(&c4, "tanh"); + // } + // { + // gsl_complex c4 = gsl_complex_sech(c3); + // print_complex(&c4, "sech"); + // } + // { + // gsl_complex c4 = gsl_complex_csch(c3); + // print_complex(&c4, "csch"); + // } + // { + // gsl_complex c4 = gsl_complex_coth(c3); + // print_complex(&c4, "coth"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsinh(c3); + // print_complex(&c4, "arcsinh"); + // } + // { + // gsl_complex c4 = gsl_complex_arccosh(c3); + // print_complex(&c4, "arccosh"); + // } + // { + // gsl_complex c4 = gsl_complex_arccosh_real(5.); + // print_complex(&c4, "arccosh_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arctanh(c3); + // print_complex(&c4, "arctanh"); + // } + // { + // gsl_complex c4 = gsl_complex_arctanh_real(5.); + // print_complex(&c4, "arctanh_real"); + // } + // { + // gsl_complex c4 = gsl_complex_arcsech(c3); + // print_complex(&c4, "arcsech"); + // } + // { + // gsl_complex c4 = gsl_complex_arccsch(c3); + // print_complex(&c4, "arccsch"); + // } + // { + // gsl_complex c4 = gsl_complex_arccoth(c3); + // print_complex(&c4, "arccoth"); + // } + // return 0; + // } + // ``` + + #[test] + #[allow(deprecated)] + fn complex_f64() { + type C = num_complex::Complex; + use crate::complex::ComplexOps; + + let v = C::rect(10., 10.); + assert_eq!(v, C { re: 10., im: 10. }); + let v2 = C::rect(1., -1.); + assert_eq!(v2, C { re: 1., im: -1. }); + let v = C::polar(5., 7.); + assert_eq!( + format!("{:.4} {:.4}", v.re, v.im), + "3.7695 3.2849".to_owned() + ); + + let arg = v.arg(); + assert_eq!(format!("{:.4}", arg), "0.7168".to_owned()); + let arg = v.abs(); + assert_eq!(format!("{:.3}", arg), "5.000".to_owned()); + let arg = v.abs2(); + assert_eq!(format!("{:.3}", arg), "25.000".to_owned()); + let arg = v.logabs(); + assert_eq!(format!("{:.4}", arg), "1.6094".to_owned()); + + let v3 = v.add(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "4.7695 2.2849".to_owned() + ); + let v3 = v.sub(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.7695 4.2849".to_owned() + ); + let v3 = v.mul(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "7.0544 -0.4846".to_owned() + ); + let v3 = v.div(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2423 3.5272".to_owned() + ); + let v3 = v.add_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "8.7695 3.2849".to_owned() + ); + let v3 = v.sub_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-1.2305 3.2849".to_owned() + ); + let v3 = v.mul_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "18.8476 16.4247".to_owned() + ); + let v3 = v.div_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.7539 0.6570".to_owned() + ); + + let v3 = v.add_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 8.2849".to_owned() + ); + let v3 = v.sub_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 -1.7151".to_owned() + ); + let v3 = v.mul_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-16.4247 18.8476".to_owned() + ); + let v3 = v.div_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.6570 -0.7539".to_owned() + ); + + let v3 = v.conjugate(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 -3.2849".to_owned() + ); + let v3 = v.inverse(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1508 -0.1314".to_owned() + ); + let v3 = v.negative(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-3.7695 -3.2849".to_owned() + ); + let v3 = v.sqrt(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.0940 0.7844".to_owned() + ); + let v3 = C::sqrt_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.2361 0.0000".to_owned() + ); + + let v3 = v.pow(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "6.4240 -7.9737".to_owned() + ); + let v3 = v.pow_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-2824.0381 -1338.0708".to_owned() + ); + let v3 = v.exp(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-42.9142 -6.1938".to_owned() + ); + let v3 = v.ln(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.6094 0.7168".to_owned() + ); + let v3 = v.log10(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.6990 0.3113".to_owned() + ); + let v3 = v.log_b(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0071 2.0523".to_owned() + ); + let v3 = v.sin(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-7.8557 -10.7913".to_owned() + ); + let v3 = v.cos(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-10.8216 7.8337".to_owned() + ); + let v3 = v.tan(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0027 0.9991".to_owned() + ); + + let v3 = v.sec(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0606 -0.0439".to_owned() + ); + let v3 = v.csc(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0441 0.0606".to_owned() + ); + let v3 = v.cot(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0027 -1.0009".to_owned() + ); + let v3 = v.arcsin(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.8440 2.3014".to_owned() + ); + let v3 = C::arcsin_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.5708 -2.2924".to_owned() + ); + let v3 = v.arccos(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.7268 -2.3014".to_owned() + ); + let v3 = C::arccos_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0000 2.2924".to_owned() + ); + let v3 = v.arctan(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.4186 0.1291".to_owned() + ); + let v3 = v.arcsec(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.4208 0.1325".to_owned() + ); + let v3 = C::arcsec_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.3694 0.0000".to_owned() + ); + let v3 = v.arccsc(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1500 -0.1325".to_owned() + ); + let v3 = C::arccsc_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2014 0.0000".to_owned() + ); + let v3 = v.arccot(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1522 -0.1291".to_owned() + ); + let v3 = v.sinh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-21.4457 -3.0986".to_owned() + ); + let v3 = v.cosh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-21.4685 -3.0953".to_owned() + ); + let v3 = v.tanh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.9990 0.0003".to_owned() + ); + let v3 = v.sech(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0456 0.0066".to_owned() + ); + let v3 = v.csch(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0457 0.0066".to_owned() + ); + let v3 = v.coth(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.0010 -0.0003".to_owned() + ); + let v3 = v.arcsinh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.3041 0.7070".to_owned() + ); + let v3 = v.arccosh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.3014 0.7268".to_owned() + ); + let v3 = C::arccosh_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.2924 0.0000".to_owned() + ); + let v3 = v.arctanh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1493 1.4372".to_owned() + ); + let v3 = C::arctanh_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2027 -1.5708".to_owned() + ); + let v3 = v.arcsech(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1325 -1.4208".to_owned() + ); + let v3 = v.arccsch(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1515 -0.1303".to_owned() + ); + let v3 = v.arccoth(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1493 -0.1336".to_owned() + ); + } + + #[test] + #[allow(deprecated)] + fn complex_f32() { + type C = num_complex::Complex; + use crate::complex::ComplexOps; + + let v = C::rect(10., 10.); + assert_eq!(v, C { re: 10., im: 10. }); + let v2 = C::rect(1., -1.); + assert_eq!(v2, C { re: 1., im: -1. }); + let v = C::polar(5., 7.); + assert_eq!( + format!("{:.4} {:.4}", v.re, v.im), + "3.7695 3.2849".to_owned() + ); + + let arg = v.arg(); + assert_eq!(format!("{:.4}", arg), "0.7168".to_owned()); + let arg = v.abs(); + assert_eq!(format!("{:.3}", arg), "5.000".to_owned()); + let arg = v.abs2(); + assert_eq!(format!("{:.3}", arg), "25.000".to_owned()); + let arg = v.logabs(); + assert_eq!(format!("{:.4}", arg), "1.6094".to_owned()); + + let v3 = v.add(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "4.7695 2.2849".to_owned() + ); + let v3 = v.sub(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.7695 4.2849".to_owned() + ); + let v3 = v.mul(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "7.0544 -0.4846".to_owned() + ); + let v3 = v.div(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2423 3.5272".to_owned() + ); + let v3 = v.add_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "8.7695 3.2849".to_owned() + ); + let v3 = v.sub_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-1.2305 3.2849".to_owned() + ); + let v3 = v.mul_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "18.8476 16.4247".to_owned() + ); + let v3 = v.div_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.7539 0.6570".to_owned() + ); + + let v3 = v.add_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 8.2849".to_owned() + ); + let v3 = v.sub_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 -1.7151".to_owned() + ); + let v3 = v.mul_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-16.4247 18.8476".to_owned() + ); + let v3 = v.div_imag(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.6570 -0.7539".to_owned() + ); + + let v3 = v.conjugate(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "3.7695 -3.2849".to_owned() + ); + let v3 = v.inverse(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1508 -0.1314".to_owned() + ); + let v3 = v.negative(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-3.7695 -3.2849".to_owned() + ); + let v3 = v.sqrt(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.0940 0.7844".to_owned() + ); + let v3 = C::sqrt_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.2361 0.0000".to_owned() + ); + + let v3 = v.pow(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "6.4240 -7.9737".to_owned() + ); + let v3 = v.pow_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-2824.0381 -1338.0712".to_owned() + ); + let v3 = v.exp(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-42.9142 -6.1938".to_owned() + ); + let v3 = v.ln(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.6094 0.7168".to_owned() + ); + let v3 = v.log10(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.6990 0.3113".to_owned() + ); + let v3 = v.log_b(&v2); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0071 2.0523".to_owned() + ); + let v3 = v.sin(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-7.8557 -10.7913".to_owned() + ); + let v3 = v.cos(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-10.8216 7.8337".to_owned() + ); + let v3 = v.tan(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0027 0.9991".to_owned() + ); + + let v3 = v.sec(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0606 -0.0439".to_owned() + ); + let v3 = v.csc(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0441 0.0606".to_owned() + ); + let v3 = v.cot(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0027 -1.0009".to_owned() + ); + let v3 = v.arcsin(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.8440 2.3014".to_owned() + ); + let v3 = C::arcsin_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.5708 -2.2924".to_owned() + ); + let v3 = v.arccos(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.7268 -2.3014".to_owned() + ); + let v3 = C::arccos_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.0000 2.2924".to_owned() + ); + let v3 = v.arctan(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.4186 0.1291".to_owned() + ); + let v3 = v.arcsec(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.4208 0.1325".to_owned() + ); + let v3 = C::arcsec_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.3694 0.0000".to_owned() + ); + let v3 = v.arccsc(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1500 -0.1325".to_owned() + ); + let v3 = C::arccsc_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2014 0.0000".to_owned() + ); + let v3 = v.arccot(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1522 -0.1291".to_owned() + ); + let v3 = v.sinh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-21.4457 -3.0986".to_owned() + ); + let v3 = v.cosh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-21.4685 -3.0953".to_owned() + ); + let v3 = v.tanh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.9990 0.0003".to_owned() + ); + let v3 = v.sech(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0456 0.0066".to_owned() + ); + let v3 = v.csch(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "-0.0457 0.0066".to_owned() + ); + let v3 = v.coth(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "1.0010 -0.0003".to_owned() + ); + let v3 = v.arcsinh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.3041 0.7070".to_owned() + ); + let v3 = v.arccosh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.3014 0.7268".to_owned() + ); + let v3 = C::arccosh_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "2.2924 0.0000".to_owned() + ); + let v3 = v.arctanh(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1493 1.4372".to_owned() + ); + let v3 = C::arctanh_real(5.); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.2027 -1.5708".to_owned() + ); + let v3 = v.arcsech(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1325 -1.4208".to_owned() + ); + let v3 = v.arccsch(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1515 -0.1303".to_owned() + ); + let v3 = v.arccoth(); + assert_eq!( + format!("{:.4} {:.4}", v3.re, v3.im), + "0.1493 -0.1336".to_owned() + ); } } - -// All these tests have been tested against the following C code: -// -// ```ignore -// #include -// #include -// #include -// #include -// -// void print_complex(gsl_complex *c, const char *text) { -// printf("%s: %f %f\n", text, c->dat[0], c->dat[1]); -// } -// -// int main (void) -// { -// gsl_complex c = gsl_complex_rect(10., 10.); -// gsl_complex c2 = gsl_complex_rect(1., -1.); -// print_complex(&c, "rect"); -// print_complex(&c2, "rect"); -// gsl_complex c3 = gsl_complex_polar(5., 7.); -// print_complex(&c3, "polar"); -// -// printf("-> %f\n", gsl_complex_arg(c3)); -// printf("-> %f\n", gsl_complex_abs(c3)); -// printf("-> %f\n", gsl_complex_abs2(c3)); -// printf("-> %f\n", gsl_complex_logabs(c3)); -// -// { -// gsl_complex c4 = gsl_complex_add(c3, c2); -// print_complex(&c4, "\nadd"); -// } -// { -// gsl_complex c4 = gsl_complex_sub(c3, c2); -// print_complex(&c4, "sub"); -// } -// { -// gsl_complex c4 = gsl_complex_mul(c3, c2); -// print_complex(&c4, "mul"); -// } -// { -// gsl_complex c4 = gsl_complex_div(c3, c2); -// print_complex(&c4, "div"); -// } -// { -// gsl_complex c4 = gsl_complex_add_real(c3, 5.); -// print_complex(&c4, "add_real"); -// } -// { -// gsl_complex c4 = gsl_complex_sub_real(c3, 5.); -// print_complex(&c4, "sub_real"); -// } -// { -// gsl_complex c4 = gsl_complex_mul_real(c3, 5.); -// print_complex(&c4, "mul_real"); -// } -// { -// gsl_complex c4 = gsl_complex_div_real(c3, 5.); -// print_complex(&c4, "div_real"); -// } -// -// -// { -// gsl_complex c4 = gsl_complex_add_imag(c3, 5.); -// print_complex(&c4, "\nadd_imag"); -// } -// { -// gsl_complex c4 = gsl_complex_sub_imag(c3, 5.); -// print_complex(&c4, "sub_imag"); -// } -// { -// gsl_complex c4 = gsl_complex_mul_imag(c3, 5.); -// print_complex(&c4, "mul_imag"); -// } -// { -// gsl_complex c4 = gsl_complex_div_imag(c3, 5.); -// print_complex(&c4, "div_imag"); -// } -// -// -// { -// gsl_complex c4 = gsl_complex_conjugate(c3); -// print_complex(&c4, "\nconjugate"); -// } -// { -// gsl_complex c4 = gsl_complex_inverse(c3); -// print_complex(&c4, "inverse"); -// } -// { -// gsl_complex c4 = gsl_complex_negative(c3); -// print_complex(&c4, "negative"); -// } -// { -// gsl_complex c4 = gsl_complex_sqrt(c3); -// print_complex(&c4, "sqrt"); -// } -// { -// gsl_complex c4 = gsl_complex_sqrt_real(5.); -// print_complex(&c4, "sqrt_real"); -// } -// -// -// { -// gsl_complex c4 = gsl_complex_pow(c3, c2); -// print_complex(&c4, "\npow"); -// } -// { -// gsl_complex c4 = gsl_complex_pow_real(c3, 5.); -// print_complex(&c4, "pow_real"); -// } -// { -// gsl_complex c4 = gsl_complex_exp(c3); -// print_complex(&c4, "exp"); -// } -// { -// gsl_complex c4 = gsl_complex_log(c3); -// print_complex(&c4, "log"); -// } -// { -// gsl_complex c4 = gsl_complex_log10(c3); -// print_complex(&c4, "log10"); -// } -// { -// gsl_complex c4 = gsl_complex_log_b(c3, c2); -// print_complex(&c4, "log_b"); -// } -// { -// gsl_complex c4 = gsl_complex_sin(c3); -// print_complex(&c4, "sin"); -// } -// { -// gsl_complex c4 = gsl_complex_cos(c3); -// print_complex(&c4, "cos"); -// } -// { -// gsl_complex c4 = gsl_complex_tan(c3); -// print_complex(&c4, "tan"); -// } -// -// -// { -// gsl_complex c4 = gsl_complex_sec(c3); -// print_complex(&c4, "\nsec"); -// } -// { -// gsl_complex c4 = gsl_complex_csc(c3); -// print_complex(&c4, "csc"); -// } -// { -// gsl_complex c4 = gsl_complex_cot(c3); -// print_complex(&c4, "cot"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsin(c3); -// print_complex(&c4, "arcsin"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsin_real(5.); -// print_complex(&c4, "arcsin_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arccos(c3); -// print_complex(&c4, "arccos"); -// } -// { -// gsl_complex c4 = gsl_complex_arccos_real(5.); -// print_complex(&c4, "arccos_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arctan(c3); -// print_complex(&c4, "arctan"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsec(c3); -// print_complex(&c4, "arcsec"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsec_real(5.); -// print_complex(&c4, "arcsec_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arccsc(c3); -// print_complex(&c4, "arccsc"); -// } -// { -// gsl_complex c4 = gsl_complex_arccsc_real(5.); -// print_complex(&c4, "arccsc_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arccot(c3); -// print_complex(&c4, "arccot"); -// } -// { -// gsl_complex c4 = gsl_complex_sinh(c3); -// print_complex(&c4, "sinh"); -// } -// { -// gsl_complex c4 = gsl_complex_cosh(c3); -// print_complex(&c4, "cosh"); -// } -// { -// gsl_complex c4 = gsl_complex_tanh(c3); -// print_complex(&c4, "tanh"); -// } -// { -// gsl_complex c4 = gsl_complex_sech(c3); -// print_complex(&c4, "sech"); -// } -// { -// gsl_complex c4 = gsl_complex_csch(c3); -// print_complex(&c4, "csch"); -// } -// { -// gsl_complex c4 = gsl_complex_coth(c3); -// print_complex(&c4, "coth"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsinh(c3); -// print_complex(&c4, "arcsinh"); -// } -// { -// gsl_complex c4 = gsl_complex_arccosh(c3); -// print_complex(&c4, "arccosh"); -// } -// { -// gsl_complex c4 = gsl_complex_arccosh_real(5.); -// print_complex(&c4, "arccosh_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arctanh(c3); -// print_complex(&c4, "arctanh"); -// } -// { -// gsl_complex c4 = gsl_complex_arctanh_real(5.); -// print_complex(&c4, "arctanh_real"); -// } -// { -// gsl_complex c4 = gsl_complex_arcsech(c3); -// print_complex(&c4, "arcsech"); -// } -// { -// gsl_complex c4 = gsl_complex_arccsch(c3); -// print_complex(&c4, "arccsch"); -// } -// { -// gsl_complex c4 = gsl_complex_arccoth(c3); -// print_complex(&c4, "arccoth"); -// } -// return 0; -// } -// ``` -#[test] -fn complex_f64() { - let v = ComplexF64::rect(10., 10.); - assert_eq!(v, ComplexF64 { dat: [10., 10.] }); - let v2 = ComplexF64::rect(1., -1.); - assert_eq!(v2, ComplexF64 { dat: [1., -1.] }); - let v = ComplexF64::polar(5., 7.); - assert_eq!( - format!("{:.4} {:.4}", v.dat[0], v.dat[1]), - "3.7695 3.2849".to_owned() - ); - - let arg = v.arg(); - assert_eq!(format!("{:.4}", arg), "0.7168".to_owned()); - let arg = v.abs(); - assert_eq!(format!("{:.3}", arg), "5.000".to_owned()); - let arg = v.abs2(); - assert_eq!(format!("{:.3}", arg), "25.000".to_owned()); - let arg = v.logabs(); - assert_eq!(format!("{:.4}", arg), "1.6094".to_owned()); - - let v3 = v.add(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "4.7695 2.2849".to_owned() - ); - let v3 = v.sub(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.7695 4.2849".to_owned() - ); - let v3 = v.mul(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "7.0544 -0.4846".to_owned() - ); - let v3 = v.div(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2423 3.5272".to_owned() - ); - let v3 = v.add_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "8.7695 3.2849".to_owned() - ); - let v3 = v.sub_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-1.2305 3.2849".to_owned() - ); - let v3 = v.mul_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "18.8476 16.4247".to_owned() - ); - let v3 = v.div_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.7539 0.6570".to_owned() - ); - - let v3 = v.add_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 8.2849".to_owned() - ); - let v3 = v.sub_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 -1.7151".to_owned() - ); - let v3 = v.mul_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-16.4247 18.8476".to_owned() - ); - let v3 = v.div_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.6570 -0.7539".to_owned() - ); - - let v3 = v.conjugate(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 -3.2849".to_owned() - ); - let v3 = v.inverse(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1508 -0.1314".to_owned() - ); - let v3 = v.negative(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-3.7695 -3.2849".to_owned() - ); - let v3 = v.sqrt(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.0940 0.7844".to_owned() - ); - let v3 = ComplexF64::sqrt_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.2361 0.0000".to_owned() - ); - - let v3 = v.pow(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "6.4240 -7.9737".to_owned() - ); - let v3 = v.pow_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-2824.0381 -1338.0708".to_owned() - ); - let v3 = v.exp(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-42.9142 -6.1938".to_owned() - ); - let v3 = v.log(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.6094 0.7168".to_owned() - ); - let v3 = v.log10(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.6990 0.3113".to_owned() - ); - let v3 = v.log_b(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0071 2.0523".to_owned() - ); - let v3 = v.sin(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-7.8557 -10.7913".to_owned() - ); - let v3 = v.cos(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-10.8216 7.8337".to_owned() - ); - let v3 = v.tan(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0027 0.9991".to_owned() - ); - - let v3 = v.sec(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0606 -0.0439".to_owned() - ); - let v3 = v.csc(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0441 0.0606".to_owned() - ); - let v3 = v.cot(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0027 -1.0009".to_owned() - ); - let v3 = v.arcsin(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.8440 2.3014".to_owned() - ); - let v3 = ComplexF64::arcsin_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.5708 -2.2924".to_owned() - ); - let v3 = v.arccos(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.7268 -2.3014".to_owned() - ); - let v3 = ComplexF64::arccos_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0000 2.2924".to_owned() - ); - let v3 = v.arctan(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.4186 0.1291".to_owned() - ); - let v3 = v.arcsec(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.4208 0.1325".to_owned() - ); - let v3 = ComplexF64::arcsec_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.3694 0.0000".to_owned() - ); - let v3 = v.arccsc(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1500 -0.1325".to_owned() - ); - let v3 = ComplexF64::arccsc_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2014 0.0000".to_owned() - ); - let v3 = v.arccot(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1522 -0.1291".to_owned() - ); - let v3 = v.sinh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-21.4457 -3.0986".to_owned() - ); - let v3 = v.cosh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-21.4685 -3.0953".to_owned() - ); - let v3 = v.tanh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.9990 0.0003".to_owned() - ); - let v3 = v.sech(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0456 0.0066".to_owned() - ); - let v3 = v.csch(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0457 0.0066".to_owned() - ); - let v3 = v.coth(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.0010 -0.0003".to_owned() - ); - let v3 = v.arcsinh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.3041 0.7070".to_owned() - ); - let v3 = v.arccosh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.3014 0.7268".to_owned() - ); - let v3 = ComplexF64::arccosh_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.2924 0.0000".to_owned() - ); - let v3 = v.arctanh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1493 1.4372".to_owned() - ); - let v3 = ComplexF64::arctanh_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2027 -1.5708".to_owned() - ); - let v3 = v.arcsech(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1325 -1.4208".to_owned() - ); - let v3 = v.arccsch(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1515 -0.1303".to_owned() - ); - let v3 = v.arccoth(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1493 -0.1336".to_owned() - ); -} - -#[test] -fn complex_f32() { - let v = ComplexF32::rect(10., 10.); - assert_eq!(v, ComplexF32 { dat: [10., 10.] }); - let v2 = ComplexF32::rect(1., -1.); - assert_eq!(v2, ComplexF32 { dat: [1., -1.] }); - let v = ComplexF32::polar(5., 7.); - assert_eq!( - format!("{:.4} {:.4}", v.dat[0], v.dat[1]), - "3.7695 3.2849".to_owned() - ); - - let arg = v.arg(); - assert_eq!(format!("{:.4}", arg), "0.7168".to_owned()); - let arg = v.abs(); - assert_eq!(format!("{:.3}", arg), "5.000".to_owned()); - let arg = v.abs2(); - assert_eq!(format!("{:.3}", arg), "25.000".to_owned()); - let arg = v.logabs(); - assert_eq!(format!("{:.4}", arg), "1.6094".to_owned()); - - let v3 = v.add(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "4.7695 2.2849".to_owned() - ); - let v3 = v.sub(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.7695 4.2849".to_owned() - ); - let v3 = v.mul(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "7.0544 -0.4846".to_owned() - ); - let v3 = v.div(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2423 3.5272".to_owned() - ); - let v3 = v.add_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "8.7695 3.2849".to_owned() - ); - let v3 = v.sub_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-1.2305 3.2849".to_owned() - ); - let v3 = v.mul_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "18.8476 16.4247".to_owned() - ); - let v3 = v.div_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.7539 0.6570".to_owned() - ); - - let v3 = v.add_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 8.2849".to_owned() - ); - let v3 = v.sub_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 -1.7151".to_owned() - ); - let v3 = v.mul_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-16.4247 18.8476".to_owned() - ); - let v3 = v.div_imag(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.6570 -0.7539".to_owned() - ); - - let v3 = v.conjugate(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "3.7695 -3.2849".to_owned() - ); - let v3 = v.inverse(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1508 -0.1314".to_owned() - ); - let v3 = v.negative(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-3.7695 -3.2849".to_owned() - ); - let v3 = v.sqrt(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.0940 0.7844".to_owned() - ); - let v3 = ComplexF32::sqrt_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.2361 0.0000".to_owned() - ); - - let v3 = v.pow(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "6.4240 -7.9737".to_owned() - ); - let v3 = v.pow_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-2824.0381 -1338.0712".to_owned() - ); - let v3 = v.exp(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-42.9142 -6.1938".to_owned() - ); - let v3 = v.log(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.6094 0.7168".to_owned() - ); - let v3 = v.log10(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.6990 0.3113".to_owned() - ); - let v3 = v.log_b(&v2); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0071 2.0523".to_owned() - ); - let v3 = v.sin(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-7.8557 -10.7913".to_owned() - ); - let v3 = v.cos(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-10.8216 7.8337".to_owned() - ); - let v3 = v.tan(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0027 0.9991".to_owned() - ); - - let v3 = v.sec(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0606 -0.0439".to_owned() - ); - let v3 = v.csc(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0441 0.0606".to_owned() - ); - let v3 = v.cot(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0027 -1.0009".to_owned() - ); - let v3 = v.arcsin(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.8440 2.3014".to_owned() - ); - let v3 = ComplexF32::arcsin_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.5708 -2.2924".to_owned() - ); - let v3 = v.arccos(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.7268 -2.3014".to_owned() - ); - let v3 = ComplexF32::arccos_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.0000 2.2924".to_owned() - ); - let v3 = v.arctan(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.4186 0.1291".to_owned() - ); - let v3 = v.arcsec(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.4208 0.1325".to_owned() - ); - let v3 = ComplexF32::arcsec_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.3694 0.0000".to_owned() - ); - let v3 = v.arccsc(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1500 -0.1325".to_owned() - ); - let v3 = ComplexF32::arccsc_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2014 0.0000".to_owned() - ); - let v3 = v.arccot(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1522 -0.1291".to_owned() - ); - let v3 = v.sinh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-21.4457 -3.0986".to_owned() - ); - let v3 = v.cosh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-21.4685 -3.0953".to_owned() - ); - let v3 = v.tanh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.9990 0.0003".to_owned() - ); - let v3 = v.sech(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0456 0.0066".to_owned() - ); - let v3 = v.csch(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "-0.0457 0.0066".to_owned() - ); - let v3 = v.coth(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "1.0010 -0.0003".to_owned() - ); - let v3 = v.arcsinh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.3041 0.7070".to_owned() - ); - let v3 = v.arccosh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.3014 0.7268".to_owned() - ); - let v3 = ComplexF32::arccosh_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "2.2924 0.0000".to_owned() - ); - let v3 = v.arctanh(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1493 1.4372".to_owned() - ); - let v3 = ComplexF32::arctanh_real(5.); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.2027 -1.5708".to_owned() - ); - let v3 = v.arcsech(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1325 -1.4208".to_owned() - ); - let v3 = v.arccsch(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1515 -0.1303".to_owned() - ); - let v3 = v.arccoth(); - assert_eq!( - format!("{:.4} {:.4}", v3.dat[0], v3.dat[1]), - "0.1493 -0.1336".to_owned() - ); -} diff --git a/src/types/eigen_symmetric_workspace.rs b/src/types/eigen_symmetric_workspace.rs index 246408a4..64d2aa24 100644 --- a/src/types/eigen_symmetric_workspace.rs +++ b/src/types/eigen_symmetric_workspace.rs @@ -890,17 +890,18 @@ fn eigen_symmetric_vworkspace() { // ``` #[test] fn eigen_hermitian_workspace() { - use crate::ComplexF64; use MatrixComplexF64; use VectorF64; + use num_complex::Complex; + use crate::complex::ComplexOps; let mut e = EigenHermitianWorkspace::new(3).unwrap(); let mut m = MatrixComplexF64::new(2, 2).unwrap(); - m.set(0, 0, &ComplexF64::rect(5., 5.)); - m.set(0, 1, &ComplexF64::rect(1., 4.)); - m.set(1, 0, &ComplexF64::rect(2., 3.)); - m.set(1, 1, &ComplexF64::rect(5., 7.)); + m.set(0, 0, &Complex::::rect(5., 5.)); + m.set(0, 1, &Complex::::rect(1., 4.)); + m.set(1, 0, &Complex::::rect(2., 3.)); + m.set(1, 1, &Complex::::rect(5., 7.)); let mut v = VectorF64::new(2).unwrap(); e.herm(&mut m, &mut v).unwrap(); @@ -944,15 +945,16 @@ fn eigen_hermitian_workspace() { // ``` #[test] fn eigen_hermitian_vworkspace() { - use crate::ComplexF64; + use num_complex::Complex; + use crate::complex::ComplexOps; let mut e = EigenHermitianVWorkspace::new(3).unwrap(); let mut m = MatrixComplexF64::new(2, 2).unwrap(); - m.set(0, 0, &ComplexF64::rect(5., 5.)); - m.set(0, 1, &ComplexF64::rect(1., 4.)); - m.set(1, 0, &ComplexF64::rect(2., 3.)); - m.set(1, 1, &ComplexF64::rect(5., 7.)); + m.set(0, 0, &Complex::::rect(5., 5.)); + m.set(0, 1, &Complex::::rect(1., 4.)); + m.set(1, 0, &Complex::::rect(2., 3.)); + m.set(1, 1, &Complex::::rect(5., 7.)); let mut v = VectorF64::new(2).unwrap(); let mut m2 = MatrixComplexF64::new(2, 2).unwrap(); @@ -961,20 +963,20 @@ fn eigen_hermitian_vworkspace() { assert_eq!( &format!( "({:.4}, {:.4}) ({:.4}, {:.4})", - m2.get(0, 0).dat[0], - m2.get(0, 0).dat[1], - m2.get(0, 1).dat[0], - m2.get(0, 1).dat[1] + m2.get(0, 0).re, + m2.get(0, 0).im, + m2.get(0, 1).re, + m2.get(0, 1).im ), "(0.7071, 0.0000) (0.7071, 0.0000)" ); assert_eq!( &format!( "({:.4}, {:.4}) ({:.4}, {:.4})", - m2.get(1, 0).dat[0], - m2.get(1, 0).dat[1], - m2.get(1, 1).dat[0], - m2.get(1, 1).dat[1] + m2.get(1, 0).re, + m2.get(1, 0).im, + m2.get(1, 1).re, + m2.get(1, 1).im ), "(0.3922, 0.5883) (-0.3922, -0.5883)" ); diff --git a/src/types/matrix_complex.rs b/src/types/matrix_complex.rs index ba6643ac..5a0c0b4f 100644 --- a/src/types/matrix_complex.rs +++ b/src/types/matrix_complex.rs @@ -3,15 +3,19 @@ // use crate::ffi::FFI; -use crate::Value; +use crate::{ + complex::{FromC, ToC}, + Value, +}; use paste::paste; +use num_complex::Complex; use std::fmt::{self, Debug, Formatter}; macro_rules! gsl_matrix_complex { - ($rust_name:ident, $name:ident, $complex:ident, $complex_c:ident) => ( + ($rust_name:ident, $name:ident, $complex: ty, $complex_c:ident) => ( paste! { -use crate::types::{$complex, [], []}; +use crate::types::{[], []}; ffi_wrapper!( $rust_name, @@ -49,7 +53,8 @@ impl $rust_name { /// invoked and 0 is returned. #[doc(alias = $name _get)] pub fn get(&self, y: usize, x: usize) -> $complex { - unsafe { std::mem::transmute(sys::[<$name _get>](self.unwrap_shared(), y, x)) } + // FIXME: check that i, j are not out of bounds. + unsafe { sys::[<$name _get>](self.unwrap_shared(), y, x).wrap() } } /// This function sets the value of the (i,j)-th element of the matrix to value. @@ -58,7 +63,7 @@ impl $rust_name { #[doc(alias = $name _set)] pub fn set(&mut self, y: usize, x: usize, value: &$complex) -> &Self { unsafe { - sys::[<$name _set>](self.unwrap_unique(), y, x, std::mem::transmute(*value)) + sys::[<$name _set>](self.unwrap_unique(), y, x, value.unwrap()) }; self } @@ -66,7 +71,7 @@ impl $rust_name { /// This function sets all the elements of the matrix to the value x. #[doc(alias = $name _set_all)] pub fn set_all(&mut self, x: &$complex) -> &Self { - unsafe { sys::[<$name _set_all>](self.unwrap_unique(), std::mem::transmute(*x)) }; + unsafe { sys::[<$name _set_all>](self.unwrap_unique(), x.unwrap()) }; self } @@ -260,7 +265,7 @@ impl $rust_name { #[doc(alias = $name _scale)] pub fn scale(&mut self, x: &$complex) -> Result<(), Value> { let ret = unsafe { - sys::[<$name _scale>](self.unwrap_unique(), std::mem::transmute(*x)) + sys::[<$name _scale>](self.unwrap_unique(), x.unwrap()) }; result_handler!(ret, ()) } @@ -270,7 +275,7 @@ impl $rust_name { #[doc(alias = $name _add_constant)] pub fn add_constant(&mut self, x: &$complex) -> Result<(), Value> { let ret = unsafe { - sys::[<$name _add_constant>](self.unwrap_unique(), std::mem::transmute(*x)) + sys::[<$name _add_constant>](self.unwrap_unique(), x.unwrap()) }; result_handler!(ret, ()) } @@ -402,12 +407,15 @@ impl Debug for $rust_name { ); // end of macro block } +type ComplexF64 = Complex; gsl_matrix_complex!( MatrixComplexF64, gsl_matrix_complex, ComplexF64, gsl_vector_complex ); + +type ComplexF32 = Complex; gsl_matrix_complex!( MatrixComplexF32, gsl_matrix_complex_float, diff --git a/src/types/mod.rs b/src/types/mod.rs index d1153d14..e14dfbcc 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -8,7 +8,7 @@ pub use self::basis_spline::BSpLineWorkspace; pub use self::chebyshev::ChebSeries; pub use self::combination::Combination; -pub use self::complex::{ComplexF32, ComplexF64}; +pub use self::complex::{ComplexF32, ComplexF64, ComplexOps}; pub use self::discrete_hankel::DiscreteHankel; pub use self::eigen_symmetric_workspace::{ EigenGenHermVWorkspace, EigenGenHermWorkspace, EigenGenSymmVWorkspace, EigenGenSymmWorkspace, diff --git a/src/types/vector_complex.rs b/src/types/vector_complex.rs index bf08d164..6c789058 100644 --- a/src/types/vector_complex.rs +++ b/src/types/vector_complex.rs @@ -3,19 +3,20 @@ // use crate::ffi::FFI; -use crate::Value; +use crate::{ + complex::{FromC, ToC}, + Value, +}; use paste::paste; use std::{ fmt::{self, Debug, Formatter}, marker::PhantomData, }; +use num_complex::Complex; macro_rules! gsl_vec_complex { ($rust_name:ident, $name:ident, $complex:ident, $rust_ty:ident) => { - paste! { - - use crate::types::$complex; - + paste! { pub struct $rust_name { vec: *mut sys::$name, can_free: bool, @@ -130,7 +131,7 @@ macro_rules! gsl_vec_complex { /// 0 to n-1 then the error handler is invoked and 0 is returned. #[doc(alias = $name _get)] pub fn get(&self, i: usize) -> $complex { - unsafe { std::mem::transmute(sys::[<$name _get>](self.unwrap_shared(), i)) } + unsafe { sys::[<$name _get>](self.unwrap_shared(), i).wrap() } } /// This function sets the value of the i-th element of a vector v to x. If i lies outside the @@ -138,7 +139,7 @@ macro_rules! gsl_vec_complex { #[doc(alias = $name _set)] pub fn set(&mut self, i: usize, x: &$complex) -> &Self { unsafe { - sys::[<$name _set>](self.unwrap_unique(), i, std::mem::transmute(*x)) + sys::[<$name _set>](self.unwrap_unique(), i, x.unwrap()) }; self } @@ -147,7 +148,7 @@ macro_rules! gsl_vec_complex { #[doc(alias = $name _set_all)] pub fn set_all(&mut self, x: &$complex) -> &Self { unsafe { - sys::[<$name _set_all>](self.unwrap_unique(), std::mem::transmute(*x)) + sys::[<$name _set_all>](self.unwrap_unique(), x.unwrap()) }; self } @@ -262,7 +263,7 @@ macro_rules! gsl_vec_complex { #[doc(alias = $name _scale)] pub fn scale(&mut self, x: &$complex) -> Result<(), Value> { let ret = unsafe { - sys::[<$name _scale>](self.unwrap_unique(), std::mem::transmute(*x)) + sys::[<$name _scale>](self.unwrap_unique(), x.unwrap()) }; result_handler!(ret, ()) } @@ -274,7 +275,7 @@ macro_rules! gsl_vec_complex { let ret = unsafe { sys::[<$name _add_constant>]( self.unwrap_unique(), - std::mem::transmute(*x), + x.unwrap(), ) }; result_handler!(ret, ()) @@ -510,5 +511,7 @@ macro_rules! gsl_vec_complex { }; // end of macro block } -gsl_vec_complex!(VectorComplexF32, gsl_vector_complex_float, ComplexF32, f32); -gsl_vec_complex!(VectorComplexF64, gsl_vector_complex, ComplexF64, f64); +type C32 = Complex; +gsl_vec_complex!(VectorComplexF32, gsl_vector_complex_float, C32, f32); +type C64 = Complex; +gsl_vec_complex!(VectorComplexF64, gsl_vector_complex, C64, f64); From b90f63e38344ae343885ab2fc6c4b8e7a4af186e Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Wed, 25 Dec 2024 23:45:46 +0100 Subject: [PATCH 18/28] Make the "fft" module use the Vector trait --- examples/fftmr.rs | 5 +- src/fft.rs | 120 ++++++++++++++++++--------- src/types/fast_fourier_transforms.rs | 24 +++++- 3 files changed, 104 insertions(+), 45 deletions(-) diff --git a/examples/fftmr.rs b/examples/fftmr.rs index 06690841..460473e6 100644 --- a/examples/fftmr.rs +++ b/examples/fftmr.rs @@ -6,6 +6,7 @@ extern crate rgsl; use rgsl::{FftComplexF64WaveTable, FftComplexF64Workspace}; +// FIXME: Make the interface use complex numbers. macro_rules! real { ($z:ident, $i:expr) => { $z[2 * ($i)] @@ -17,7 +18,7 @@ macro_rules! imag { }; } -const N: usize = 630; +const N: usize = 128; fn main() { let data = &mut [0.; 2 * N]; @@ -40,7 +41,7 @@ fn main() { println!("# factor {}: {}", i, wavetable.factor()[i]); } - workspace.forward(data, 1, N, &wavetable).unwrap(); + workspace.forward(data, &wavetable).unwrap(); for i in 0..N { println!("{}: {} {}", i, real!(data, i), imag!(data, i)); diff --git a/src/fft.rs b/src/fft.rs index cb1fefd6..f59f57d5 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -165,7 +165,10 @@ is desirable for better locality of memory accesses). /// /// The functions return a value of crate::Value::Success if no errors were detected, or Value::Dom if the length n is not a power of two. pub mod radix2 { - use crate::Value; + use crate::{ + Value, + vector::VectorMut, + }; #[doc(alias = "gsl_fft_complex_radix2_forward")] pub fn forward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { @@ -174,62 +177,85 @@ pub mod radix2 { } #[doc(alias = "gsl_fft_complex_radix2_transform")] - pub fn transform( - data: &mut [f64], - stride: usize, - n: usize, + pub fn transform + ?Sized>( + data: &mut V, sign: crate::FftDirection, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_fft_complex_radix2_transform(data.as_mut_ptr(), stride, n, sign.into()) + sys::gsl_fft_complex_radix2_transform( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + sign.into()) }; result_handler!(ret, ()) } #[doc(alias = "gsl_fft_complex_radix2_backward")] - pub fn backward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_backward(data.as_mut_ptr(), stride, n) }; + pub fn backward(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_complex_radix2_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } #[doc(alias = "gsl_fft_complex_radix2_inverse")] - pub fn inverse(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_inverse(data.as_mut_ptr(), stride, n) }; + pub fn inverse(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_complex_radix2_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_forward")] - pub fn dif_forward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_forward(data.as_mut_ptr(), stride, n) }; + pub fn dif_forward(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_complex_radix2_dif_forward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_transform")] - pub fn dif_transform( - data: &mut [f64], - stride: usize, - n: usize, + pub fn dif_transform + ?Sized>( + data: &mut V, sign: crate::FftDirection, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_fft_complex_radix2_dif_transform(data.as_mut_ptr(), stride, n, sign.into()) - }; + sys::gsl_fft_complex_radix2_dif_transform( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + sign.into()) }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_backward")] - pub fn dif_backward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_backward(data.as_mut_ptr(), stride, n) }; + pub fn dif_backward(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_complex_radix2_dif_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_inverse")] - pub fn dif_inverse(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_inverse(data.as_mut_ptr(), stride, n) }; + pub fn dif_inverse(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_complex_radix2_dif_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } } @@ -237,7 +263,10 @@ pub mod radix2 { /// This section describes radix-2 FFT algorithms for real data. They use the Cooley-Tukey algorithm to compute in-place FFTs for lengths which /// are a power of 2. pub mod real_radix2 { - use crate::Value; + use crate::{ + Value, + vector::{check_equal_len, Vector, VectorMut}, + }; /// This function computes an in-place radix-2 FFT of length n and stride stride on the real array data. The output is a half-complex sequence, /// which is stored in-place. The arrangement of the half-complex terms uses the following scheme: for k < n/2 the real part of the k-th term @@ -270,24 +299,36 @@ pub mod real_radix2 { /// Note that the output data can be converted into the full complex sequence using the function gsl_fft_halfcomplex_radix2_unpack described /// below. #[doc(alias = "gsl_fft_real_radix2_transform")] - pub fn transform(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_real_radix2_transform(data.as_mut_ptr(), stride, n) }; + pub fn transform(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_real_radix2_transform( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } /// This function computes the inverse or backwards in-place radix-2 FFT of length n and stride stride on the half-complex sequence data /// stored according the output scheme used by gsl_fft_real_radix2. The result is a real array stored in natural order. #[doc(alias = "gsl_fft_halfcomplex_radix2_inverse")] - pub fn inverse(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_inverse(data.as_mut_ptr(), stride, n) }; + pub fn inverse(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } /// This function computes the inverse or backwards in-place radix-2 FFT of length n and stride stride on the half-complex sequence data /// stored according the output scheme used by gsl_fft_real_radix2. The result is a real array stored in natural order. #[doc(alias = "gsl_fft_halfcomplex_radix2_backward")] - pub fn backward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_backward(data.as_mut_ptr(), stride, n) }; + pub fn backward(data: &mut V) -> Result<(), Value> + where V: VectorMut + ?Sized { + let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data)) }; result_handler!(ret, ()) } @@ -322,18 +363,19 @@ pub mod real_radix2 { /// } /// ``` #[doc(alias = "gsl_fft_halfcomplex_radix2_unpack")] - pub fn unpack( - halfcomplex_coefficient: &mut [f64], - complex_coefficient: &mut [f64], - stride: usize, - n: usize, - ) -> Result<(), Value> { + pub fn unpack( + halfcomplex_coefficient: &V1, + complex_coefficient: &mut V2, // FIXME: Complex + ) -> Result<(), Value> + where V1: Vector + ?Sized, + V2: VectorMut + ?Sized { + check_equal_len(halfcomplex_coefficient, halfcomplex_coefficient)?; let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_unpack( - halfcomplex_coefficient.as_mut_ptr(), - complex_coefficient.as_mut_ptr(), - stride, - n, + V1::as_slice(halfcomplex_coefficient).as_ptr(), + V2::as_mut_slice(complex_coefficient).as_mut_ptr(), + V1::stride(halfcomplex_coefficient), + V1::len(halfcomplex_coefficient), ) }; result_handler!(ret, ()) diff --git a/src/types/fast_fourier_transforms.rs b/src/types/fast_fourier_transforms.rs index fe61d853..8593d360 100644 --- a/src/types/fast_fourier_transforms.rs +++ b/src/types/fast_fourier_transforms.rs @@ -82,11 +82,15 @@ impl $complex_rust_name { data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { + if V::len(data) % 2 == 1 { + panic!("{}: the length of the data must be even", + stringify!($complex_rust_name::forward)); + } let ret = unsafe { sys::[<$name $($extra)? _forward>]( V::as_mut_slice(data).as_mut_ptr(), V::stride(data), - V::len(data), + V::len(data) / 2, // FIXME: use complex vectors? wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -101,11 +105,15 @@ impl $complex_rust_name { wavetable: &$rust_name, sign: crate::FftDirection, ) -> Result<(), Value> { + if V::len(data) % 2 == 1 { + panic!("{}: the length of the data must be even", + stringify!($complex_rust_name::transform)); + } let ret = unsafe { sys::[<$name $($extra)? _transform>]( V::as_mut_slice(data).as_mut_ptr(), V::stride(data), - V::len(data), + V::len(data) / 2, // FIXME: use complex vectors? wavetable.unwrap_shared(), self.unwrap_unique(), sign.into(), @@ -120,11 +128,15 @@ impl $complex_rust_name { data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { + if V::len(data) % 2 == 1 { + panic!("{}: the length of the data must be even", + stringify!($complex_rust_name::backward)); + } let ret = unsafe { sys::[<$name $($extra)? _backward>]( V::as_mut_slice(data).as_mut_ptr(), V::stride(data), - V::len(data), + V::len(data) / 2, // FIXME: use complex vectors? wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -138,11 +150,15 @@ impl $complex_rust_name { data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { + if V::len(data) % 2 == 1 { + panic!("{}: the length of the data must be even", + stringify!($complex_rust_name::inverse)); + } let ret = unsafe { sys::[<$name $($extra)? _inverse>]( V::as_mut_slice(data).as_mut_ptr(), V::stride(data), - V::len(data), + V::len(data) / 2, // FIXME: use complex vectors? wavetable.unwrap_shared(), self.unwrap_unique(), ) From dc8938eb570c496d9a0feb12385e07d7932d40c9 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 00:20:28 +0100 Subject: [PATCH 19/28] Act on Clippy warnings --- src/stats.rs | 19 +++++++++---------- src/types/mod.rs | 1 + src/types/monte_carlo.rs | 7 +++---- src/types/multimin.rs | 1 - src/types/multiroot.rs | 15 +++++++++------ src/types/n_tuples.rs | 4 ++-- src/types/ordinary_differential_equations.rs | 2 +- src/types/roots.rs | 2 +- src/types/vector.rs | 15 ++++++++++++--- src/utilities.rs | 2 +- src/view.rs | 4 ++-- 11 files changed, 41 insertions(+), 31 deletions(-) diff --git a/src/stats.rs b/src/stats.rs index 5a254db0..544b307a 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,6 +1,15 @@ // // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // +//! # Weighted Samples +//! +//! The functions described in this section allow the computation of +//! statistics for weighted samples. The functions accept a vector of +//! samples, xᵢ, with associated weights, wᵢ. Each sample xᵢ is +//! considered as having been drawn from a Gaussian distribution with +//! variance σᵢ². The sample weight wᵢ is defined as the reciprocal +//! of this variance, wᵢ = 1/σᵢ². Setting a weight to zero +//! corresponds to removing a sample from a dataset. use crate::vector::{self, Vector}; @@ -9,16 +18,6 @@ use crate::vector::VectorMut; // FIXME: Many functions are missing. -/// # Weighted Samples -/// -/// The functions described in this section allow the computation of -/// statistics for weighted samples. The functions accept a vector of -/// samples, xᵢ, with associated weights, wᵢ. Each sample xᵢ is -/// considered as having been drawn from a Gaussian distribution with -/// variance σᵢ². The sample weight wᵢ is defined as the reciprocal -/// of this variance, wᵢ = 1/σᵢ². Setting a weight to zero -/// corresponds to removing a sample from a dataset. - /// Return the weighted mean of the dataset `data` using the set of /// weights `w`. The weighted mean is defined as, /// ̂μ = (∑ wᵢ xᵢ) / (∑ wᵢ). diff --git a/src/types/mod.rs b/src/types/mod.rs index e14dfbcc..8ba199df 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -8,6 +8,7 @@ pub use self::basis_spline::BSpLineWorkspace; pub use self::chebyshev::ChebSeries; pub use self::combination::Combination; +#[allow(deprecated)] pub use self::complex::{ComplexF32, ComplexF64, ComplexOps}; pub use self::discrete_hankel::DiscreteHankel; pub use self::eigen_symmetric_workspace::{ diff --git a/src/types/monte_carlo.rs b/src/types/monte_carlo.rs index 11b9c1ba..aff5ae41 100644 --- a/src/types/monte_carlo.rs +++ b/src/types/monte_carlo.rs @@ -76,7 +76,6 @@ The estimates are averaged using the arithmetic mean, but no error is computed. use crate::ffi::FFI; use crate::Value; use std::marker::PhantomData; -use std::mem::transmute; use std::os::raw::c_void; use std::slice; @@ -147,7 +146,7 @@ impl PlainMonteCarlo { let f: Box = Box::new(f); let ret = unsafe { let func = sys::gsl_monte_function { - f: transmute(monte_trampoline:: as usize), + f: Some(monte_trampoline::), dim: xl.len() as _, params: Box::into_raw(f) as *mut _, }; @@ -247,7 +246,7 @@ impl MiserMonteCarlo { let f: Box = Box::new(f); let ret = unsafe { let mut func = sys::gsl_monte_function { - f: transmute(monte_trampoline:: as usize), + f: Some(monte_trampoline::), dim: xl.len() as _, params: Box::into_raw(f) as *mut _, }; @@ -409,7 +408,7 @@ impl VegasMonteCarlo { let f: Box = Box::new(f); let ret = unsafe { let mut func = sys::gsl_monte_function { - f: transmute(monte_trampoline:: as usize), + f: Some(monte_trampoline::), dim: xl.len() as _, params: Box::into_raw(f) as *mut _, }; diff --git a/src/types/multimin.rs b/src/types/multimin.rs index 064334e0..99ded955 100644 --- a/src/types/multimin.rs +++ b/src/types/multimin.rs @@ -262,7 +262,6 @@ impl<'a> MultiMinFdfFunction<'a> { } #[allow(clippy::wrong_self_convention)] - fn to_raw(&mut self) -> *mut sys::gsl_multimin_function_fdf { self.intern.n = self.n; self.intern.params = self as *mut MultiMinFdfFunction as *mut c_void; diff --git a/src/types/multiroot.rs b/src/types/multiroot.rs index 3843f299..3878b80c 100644 --- a/src/types/multiroot.rs +++ b/src/types/multiroot.rs @@ -440,16 +440,19 @@ impl MultiRootFdfSolver { std::str::from_utf8(slice).ok().map(|x| x.to_owned()) } - /// Perform a single iteration of the solver. If the iteration encounters an - /// unexpected problem then an error code will be returned, + /// Perform a single iteration of the solver. If the iteration + /// encounters an unexpected problem then an error code will be + /// returned, /// - /// * `crate::Value::BadFunc` the iteration encountered a singular point where the function or its derivative evaluated to Inf or NaN. + /// * `crate::Value::BadFunc` the iteration encountered a singular + /// point where the function or its derivative evaluated to Inf or NaN. /// /// * `crate::Value::NoProgress` the iteration is not making any progress, - /// preventing the algorithm from continuing. + /// preventing the algorithm from continuing. /// - /// The solver maintains a current best estimate of the root and its function value at all times. - /// This information can be accessed with `root`, `f`, and `dx` functions. + /// The solver maintains a current best estimate of the root and + /// its function value at all times. This information can be + /// accessed with `root`, `f`, and `dx` functions. #[doc(alias = "gsl_multiroot_fdfsolver_iterate")] pub fn iterate(&mut self) -> Result<(), Value> { let ret = unsafe { sys::gsl_multiroot_fdfsolver_iterate(self.unwrap_unique()) }; diff --git a/src/types/n_tuples.rs b/src/types/n_tuples.rs index f3da6b2c..ef8eb039 100644 --- a/src/types/n_tuples.rs +++ b/src/types/n_tuples.rs @@ -176,12 +176,12 @@ macro_rules! impl_project { let f: Box = Box::new(value_func); let mut value_function = sys::gsl_ntuple_value_fn { - function: unsafe { std::mem::transmute(value_trampoline:: as usize) }, + function: Some(value_trampoline::), params: Box::into_raw(f) as *mut _, }; let f: Box = Box::new(select_func); let mut select_function = sys::gsl_ntuple_select_fn { - function: unsafe { std::mem::transmute(select_trampoline:: as usize) }, + function: Some(select_trampoline::), params: Box::into_raw(f) as *mut _, }; let ret = unsafe { diff --git a/src/types/ordinary_differential_equations.rs b/src/types/ordinary_differential_equations.rs index 1573bf1b..9621e3df 100644 --- a/src/types/ordinary_differential_equations.rs +++ b/src/types/ordinary_differential_equations.rs @@ -880,7 +880,7 @@ impl<'a> ODEiv2Driver<'a> { } } -impl<'a> Drop for ODEiv2Driver<'a> { +impl Drop for ODEiv2Driver<'_> { #[doc(alias = "gsl_odeiv2_driver_free")] fn drop(&mut self) { unsafe { sys::gsl_odeiv2_driver_free(self.d) }; diff --git a/src/types/roots.rs b/src/types/roots.rs index 87d1a885..02dba6bd 100644 --- a/src/types/roots.rs +++ b/src/types/roots.rs @@ -92,7 +92,7 @@ impl RootFSolverType { /// The Brent-Dekker method (referred to here as Brent’s method) combines an interpo- /// lation strategy with the bisection algorithm. This produces a fast algorithm which is /// still robust. - + /// /// On each iteration Brent’s method approximates the function using an interpolating /// curve. On the first iteration this is a linear interpolation of the two endpoints. For /// subsequent iterations the algorithm uses an inverse quadratic fit to the last three diff --git a/src/types/vector.rs b/src/types/vector.rs index 4255c1b4..120e2aeb 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -103,6 +103,15 @@ pub unsafe trait Vector { } } +/// Trait implemented by types that are considered *mutable* vectors +/// by this crate. Elements of the vector are of type `F` (`f32` or `f64`). +/// +/// Bring this trait into scope in order to add methods to specify +/// strides to the types implementing `Vector`. +/// +/// # Safety +/// One must make sore that `(len - 1) * stride` does not exceed the +/// length of the underlying slice. pub unsafe trait VectorMut: Vector { /// Same as [`Vector::as_slice`] but mutable. fn as_mut_slice(x: &mut Self) -> &mut [F]; @@ -145,7 +154,7 @@ pub struct SliceMut<'a, F> { stride: usize, } -unsafe impl<'a, F> Vector for Slice<'a, F> { +unsafe impl Vector for Slice<'_, F> { #[inline] fn len(x: &Self) -> usize { x.len @@ -160,7 +169,7 @@ unsafe impl<'a, F> Vector for Slice<'a, F> { } } -unsafe impl<'a, F> Vector for SliceMut<'a, F> { +unsafe impl Vector for SliceMut<'_, F> { #[inline] fn len(x: &Self) -> usize { x.len @@ -175,7 +184,7 @@ unsafe impl<'a, F> Vector for SliceMut<'a, F> { } } -unsafe impl<'a, F> VectorMut for SliceMut<'a, F> { +unsafe impl VectorMut for SliceMut<'_, F> { #[inline] fn as_mut_slice(x: &mut Self) -> &mut [F] { x.vec diff --git a/src/utilities.rs b/src/utilities.rs index 95f748d4..95bd6b4c 100644 --- a/src/utilities.rs +++ b/src/utilities.rs @@ -27,7 +27,7 @@ impl IOStream { /// Open a file in write mode. pub fn fwrite_handle>(file: &P) -> io::Result { let path = CString::new(file.as_ref().to_str().unwrap()).unwrap(); - let ptr = unsafe { fopen(path.as_ptr(), b"w\0".as_ptr() as *const c_char) }; + let ptr = unsafe { fopen(path.as_ptr(), c"w".as_ptr() as *const c_char) }; if ptr.is_null() { return Err(io::Error::new( io::ErrorKind::Other, diff --git a/src/view.rs b/src/view.rs index 490d7a0a..3e262bb5 100644 --- a/src/view.rs +++ b/src/view.rs @@ -12,7 +12,7 @@ pub struct View<'a, T> { phantom: PhantomData<&'a ()>, } -impl<'a, T> View<'a, T> { +impl View<'_, T> { pub(crate) fn new

(inner: *mut P) -> Self where T: FFI

, @@ -24,7 +24,7 @@ impl<'a, T> View<'a, T> { } } -impl<'a, T> Deref for View<'a, T> { +impl Deref for View<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { From 07a88c37b7a1b0f2b2ab721322ad7e2cf2bcd92b Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 00:26:34 +0100 Subject: [PATCH 20/28] Allow to keep the error handler in a global static var (no warnings) --- src/error.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/error.rs b/src/error.rs index 576aa12d..dc46f115 100644 --- a/src/error.rs +++ b/src/error.rs @@ -156,6 +156,7 @@ pub fn str_error(error: crate::Value) -> &'static str { } } +// FIXME: Can do better? static mut CALLBACK: Option = None; /// `f` is the type of GSL error handler functions. An error handler will be passed four arguments @@ -197,6 +198,7 @@ static mut CALLBACK: Option = None; /// let old_handler = set_error_handler(None); /// ``` #[doc(alias = "gsl_set_error_handler")] +#[allow(static_mut_refs)] pub fn set_error_handler( f: Option, ) -> Option { @@ -220,6 +222,7 @@ pub fn set_error_handler( /// routines must be checked. This is the recommended behavior for production programs. The previous /// handler is returned (so that you can restore it later). #[doc(alias = "gsl_set_error_handler_off")] +#[allow(static_mut_refs)] pub fn set_error_handler_off() -> Option { unsafe { sys::gsl_set_error_handler_off(); From 57359918f8b62ae81ca2de4707a0b5b384c86e06 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 00:29:13 +0100 Subject: [PATCH 21/28] Run cargo fmt --- src/blas.rs | 88 +++------------- src/fft.rs | 140 ++++++++++++++++--------- src/linear_algebra.rs | 29 +---- src/polynomials.rs | 27 +---- src/types/complex.rs | 73 ++++++------- src/types/eigen_symmetric_workspace.rs | 6 +- src/types/fast_fourier_transforms.rs | 5 +- src/types/matrix_complex.rs | 2 +- src/types/vector_complex.rs | 4 +- 9 files changed, 163 insertions(+), 211 deletions(-) diff --git a/src/blas.rs b/src/blas.rs index 9e612f39..cd5da230 100644 --- a/src/blas.rs +++ b/src/blas.rs @@ -4,8 +4,12 @@ pub mod level1 { use crate::ffi::FFI; + use crate::{ + types, + types::complex::{FromC, ToC}, + Value, + }; use crate::{VectorF32, VectorF64}; - use crate::{types, types::complex::{ToC, FromC}, Value}; use num_complex::Complex; /// This function computes the sum \alpha + x^T y for the vectors x and y, returning the result @@ -283,13 +287,8 @@ pub mod level1 { x: &types::VectorComplexF32, y: &mut types::VectorComplexF32, ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_blas_caxpy( - alpha.unwrap(), - x.unwrap_shared(), - y.unwrap_unique(), - ) - }; + let ret = + unsafe { sys::gsl_blas_caxpy(alpha.unwrap(), x.unwrap_shared(), y.unwrap_unique()) }; result_handler!(ret, ()) } @@ -300,13 +299,8 @@ pub mod level1 { x: &types::VectorComplexF64, y: &mut types::VectorComplexF64, ) -> Result<(), Value> { - let ret = unsafe { - sys::gsl_blas_zaxpy( - alpha.unwrap(), - x.unwrap_shared(), - y.unwrap_unique(), - ) - }; + let ret = + unsafe { sys::gsl_blas_zaxpy(alpha.unwrap(), x.unwrap_shared(), y.unwrap_unique()) }; result_handler!(ret, ()) } @@ -358,14 +352,7 @@ pub mod level1 { pub fn srotg(mut a: f32, mut b: f32) -> Result<(f32, f32, f32), Value> { let mut c = 0.; let mut s = 0.; - let ret = unsafe { - sys::gsl_blas_srotg( - &mut a, - &mut b, - &mut c, - &mut s, - ) - }; + let ret = unsafe { sys::gsl_blas_srotg(&mut a, &mut b, &mut c, &mut s) }; result_handler!(ret, (c, s, a)) } @@ -381,14 +368,7 @@ pub mod level1 { pub fn drotg(mut a: f64, mut b: f64) -> Result<(f64, f64, f64), Value> { let mut c = 0.; let mut s = 0.; - let ret = unsafe { - sys::gsl_blas_drotg( - &mut a, - &mut b, - &mut c, - &mut s, - ) - }; + let ret = unsafe { sys::gsl_blas_drotg(&mut a, &mut b, &mut c, &mut s) }; result_handler!(ret, (c, s, a)) } @@ -420,22 +400,9 @@ pub mod level1 { /// The modified Givens transformation is defined in the original /// [Level-1 BLAS specification](https://help.imsl.com/fortran/fnlmath/current/basic-linear-algebra-sub.htm#mch9_1817247609_srotmg). #[doc(alias = "gsl_blas_srotmg")] - pub fn srotmg( - mut d1: f32, - mut d2: f32, - mut b1: f32, - b2: f32, - ) -> Result<[f32; 5], Value> { + pub fn srotmg(mut d1: f32, mut d2: f32, mut b1: f32, b2: f32) -> Result<[f32; 5], Value> { let mut p = [f32::NAN; 5]; - let ret = unsafe { - sys::gsl_blas_srotmg( - &mut d1, - &mut d2, - &mut b1, - b2, - p.as_mut_ptr(), - ) - }; + let ret = unsafe { sys::gsl_blas_srotmg(&mut d1, &mut d2, &mut b1, b2, p.as_mut_ptr()) }; result_handler!(ret, p) } @@ -443,22 +410,9 @@ pub mod level1 { /// The modified Givens transformation is defined in the original /// [Level-1 BLAS specification](https://help.imsl.com/fortran/fnlmath/current/basic-linear-algebra-sub.htm#mch9_1817247609_srotmg). #[doc(alias = "gsl_blas_drotmg")] - pub fn drotmg( - mut d1: f64, - mut d2: f64, - mut b1: f64, - b2: f64, - ) -> Result<[f64; 5], Value> { + pub fn drotmg(mut d1: f64, mut d2: f64, mut b1: f64, b2: f64) -> Result<[f64; 5], Value> { let mut p = [f64::NAN; 5]; - let ret = unsafe { - sys::gsl_blas_drotmg( - &mut d1, - &mut d2, - &mut b1, - b2, - p.as_mut_ptr(), - ) - }; + let ret = unsafe { sys::gsl_blas_drotmg(&mut d1, &mut d2, &mut b1, b2, p.as_mut_ptr()) }; result_handler!(ret, p) } @@ -474,11 +428,7 @@ pub mod level1 { if lenx != leny { panic!("rgsl::blas::srotm: len(x) = {lenx} != len(y) = {leny}") } - let ret = - unsafe { sys::gsl_blas_srotm( - x.unwrap_unique(), - y.unwrap_unique(), - p.as_ptr()) }; + let ret = unsafe { sys::gsl_blas_srotm(x.unwrap_unique(), y.unwrap_unique(), p.as_ptr()) }; result_handler!(ret, ()) } @@ -494,11 +444,7 @@ pub mod level1 { if lenx != leny { panic!("rgsl::blas::drotm: len(x) = {lenx} != len(y) = {leny}") } - let ret = - unsafe { sys::gsl_blas_drotm( - x.unwrap_unique(), - y.unwrap_unique(), - p.as_ptr()) }; + let ret = unsafe { sys::gsl_blas_drotm(x.unwrap_unique(), y.unwrap_unique(), p.as_ptr()) }; result_handler!(ret, ()) } diff --git a/src/fft.rs b/src/fft.rs index f59f57d5..2cb820db 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -165,10 +165,7 @@ is desirable for better locality of memory accesses). /// /// The functions return a value of crate::Value::Success if no errors were detected, or Value::Dom if the length n is not a power of two. pub mod radix2 { - use crate::{ - Value, - vector::VectorMut, - }; + use crate::{vector::VectorMut, Value}; #[doc(alias = "gsl_fft_complex_radix2_forward")] pub fn forward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { @@ -186,39 +183,55 @@ pub mod radix2 { V::as_mut_slice(data).as_mut_ptr(), V::stride(data), V::len(data), - sign.into()) + sign.into(), + ) }; result_handler!(ret, ()) } #[doc(alias = "gsl_fft_complex_radix2_backward")] pub fn backward(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_complex_radix2_backward( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } #[doc(alias = "gsl_fft_complex_radix2_inverse")] pub fn inverse(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_complex_radix2_inverse( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_forward")] pub fn dif_forward(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_forward( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_dif_forward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } @@ -233,29 +246,41 @@ pub mod radix2 { V::as_mut_slice(data).as_mut_ptr(), V::stride(data), V::len(data), - sign.into()) }; + sign.into(), + ) + }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_backward")] pub fn dif_backward(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_backward( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_dif_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_inverse")] pub fn dif_inverse(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_complex_radix2_dif_inverse( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_dif_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } } @@ -264,8 +289,8 @@ pub mod radix2 { /// are a power of 2. pub mod real_radix2 { use crate::{ - Value, vector::{check_equal_len, Vector, VectorMut}, + Value, }; /// This function computes an in-place radix-2 FFT of length n and stride stride on the real array data. The output is a half-complex sequence, @@ -300,11 +325,16 @@ pub mod real_radix2 { /// below. #[doc(alias = "gsl_fft_real_radix2_transform")] pub fn transform(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_real_radix2_transform( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_real_radix2_transform( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } @@ -312,11 +342,16 @@ pub mod real_radix2 { /// stored according the output scheme used by gsl_fft_real_radix2. The result is a real array stored in natural order. #[doc(alias = "gsl_fft_halfcomplex_radix2_inverse")] pub fn inverse(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_inverse( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_halfcomplex_radix2_inverse( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } @@ -324,11 +359,16 @@ pub mod real_radix2 { /// stored according the output scheme used by gsl_fft_real_radix2. The result is a real array stored in natural order. #[doc(alias = "gsl_fft_halfcomplex_radix2_backward")] pub fn backward(data: &mut V) -> Result<(), Value> - where V: VectorMut + ?Sized { - let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_backward( - V::as_mut_slice(data).as_mut_ptr(), - V::stride(data), - V::len(data)) }; + where + V: VectorMut + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_halfcomplex_radix2_backward( + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } @@ -367,8 +407,10 @@ pub mod real_radix2 { halfcomplex_coefficient: &V1, complex_coefficient: &mut V2, // FIXME: Complex ) -> Result<(), Value> - where V1: Vector + ?Sized, - V2: VectorMut + ?Sized { + where + V1: Vector + ?Sized, + V2: VectorMut + ?Sized, + { check_equal_len(halfcomplex_coefficient, halfcomplex_coefficient)?; let ret = unsafe { sys::gsl_fft_halfcomplex_radix2_unpack( diff --git a/src/linear_algebra.rs b/src/linear_algebra.rs index 2646f40e..eea0f12e 100644 --- a/src/linear_algebra.rs +++ b/src/linear_algebra.rs @@ -164,10 +164,7 @@ James Demmel, Krešimir Veselić, “Jacobi’s Method is more accurate than QR use crate::complex::ToC; use crate::enums; use crate::ffi::FFI; -use crate::{ - complex::FromC, - Value, -}; +use crate::{complex::FromC, Value}; use num_complex::Complex; /// Factorise a general N x N matrix A into, @@ -1275,11 +1272,7 @@ pub fn householder_transform(v: &mut crate::VectorF64) -> f64 { /// the first. On output the transformation is stored in the vector v and the scalar \tau is returned. #[doc(alias = "gsl_linalg_complex_householder_transform")] pub fn complex_householder_transform(v: &mut crate::VectorComplexF64) -> Complex { - unsafe { - sys::gsl_linalg_complex_householder_transform( - v.unwrap_unique(), - ).wrap() - } + unsafe { sys::gsl_linalg_complex_householder_transform(v.unwrap_unique()).wrap() } } /// This function applies the Householder matrix P defined by the scalar tau and the vector v to the left-hand side of the matrix A. On output @@ -1303,11 +1296,7 @@ pub fn complex_householder_hm( a: &mut crate::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_linalg_complex_householder_hm( - tau.unwrap(), - v.unwrap_shared(), - a.unwrap_unique(), - ) + sys::gsl_linalg_complex_householder_hm(tau.unwrap(), v.unwrap_shared(), a.unwrap_unique()) }; result_handler!(ret, ()) } @@ -1333,11 +1322,7 @@ pub fn complex_householder_mh( a: &mut crate::MatrixComplexF64, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_linalg_complex_householder_mh( - tau.unwrap(), - v.unwrap_shared(), - a.unwrap_unique(), - ) + sys::gsl_linalg_complex_householder_mh(tau.unwrap(), v.unwrap_shared(), a.unwrap_unique()) }; result_handler!(ret, ()) } @@ -1363,11 +1348,7 @@ pub fn complex_householder_hv( w: &mut crate::VectorComplexF64, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_linalg_complex_householder_hv( - tau.unwrap(), - v.unwrap_shared(), - w.unwrap_unique(), - ) + sys::gsl_linalg_complex_householder_hv(tau.unwrap(), v.unwrap_shared(), w.unwrap_unique()) }; result_handler!(ret, ()) } diff --git a/src/polynomials.rs b/src/polynomials.rs index d43e0837..bc90d969 100644 --- a/src/polynomials.rs +++ b/src/polynomials.rs @@ -27,7 +27,7 @@ R. L. Burden and J. D. Faires, Numerical Analysis, 9th edition, ISBN 0-538-73351 /// stability. pub mod evaluation { use crate::{ - types::complex::{ToC, FromC}, + types::complex::{FromC, ToC}, Value, }; use num_complex::Complex; @@ -41,13 +41,7 @@ pub mod evaluation { /// This function evaluates a polynomial with real coefficients for the complex variable z. #[doc(alias = "gsl_poly_complex_eval")] pub fn poly_complex_eval(c: &[f64], z: &Complex) -> Complex { - unsafe { - sys::gsl_poly_complex_eval( - c.as_ptr(), - c.len() as i32, - z.unwrap(), - ).wrap() - } + unsafe { sys::gsl_poly_complex_eval(c.as_ptr(), c.len() as i32, z.unwrap()).wrap() } } /// This function evaluates a polynomial with complex coefficients for the complex variable z. @@ -60,11 +54,7 @@ pub mod evaluation { tmp.push(it.unwrap()) } unsafe { - sys::gsl_complex_poly_complex_eval( - tmp.as_ptr(), - tmp.len() as i32, - z.unwrap(), - ).wrap() + sys::gsl_complex_poly_complex_eval(tmp.as_ptr(), tmp.len() as i32, z.unwrap()).wrap() } } @@ -221,10 +211,7 @@ pub mod quadratic_equations { z1: &mut Complex, ) -> Result<(), Value> { let ret = - unsafe { sys::gsl_poly_complex_solve_quadratic( - a, b, c, - transmute(z0), - transmute(z1)) }; + unsafe { sys::gsl_poly_complex_solve_quadratic(a, b, c, transmute(z0), transmute(z1)) }; result_handler!(ret, ()) } } @@ -274,11 +261,7 @@ pub mod cubic_equations { z2: &mut Complex, ) -> Result<(), Value> { let ret = unsafe { - sys::gsl_poly_complex_solve_cubic( - a, b, c, - transmute(z0), - transmute(z1), - transmute(z2)) + sys::gsl_poly_complex_solve_cubic(a, b, c, transmute(z0), transmute(z1), transmute(z2)) }; result_handler!(ret, ()) } diff --git a/src/types/complex.rs b/src/types/complex.rs index ecc2d208..06cded69 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -4,9 +4,9 @@ use num_complex::Complex; -#[deprecated(since="8.0.0", note="use `Complex` instead")] +#[deprecated(since = "8.0.0", note = "use `Complex` instead")] pub type ComplexF64 = Complex; -#[deprecated(since="8.0.0", note="use `Complex` instead")] +#[deprecated(since = "8.0.0", note = "use `Complex` instead")] pub type ComplexF32 = Complex; pub(crate) trait ToC { @@ -44,13 +44,13 @@ pub trait ComplexOps { /// This function returns the magnitude of the complex number z, |z|. #[doc(alias = "gsl_complex_abs")] - #[deprecated(since="8.0.0", note="please use `.norm()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.norm()` instead")] fn abs(&self) -> T; /// This function returns the squared magnitude of the complex /// number z = `self`, |z|². #[doc(alias = "gsl_complex_abs2")] - #[deprecated(since="8.0.0", note="please use `.norm_sqr()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.norm_sqr()` instead")] fn abs2(&self) -> T; /// This function returns the natural logarithm of the magnitude @@ -65,90 +65,90 @@ pub trait ComplexOps { /// This function returns the sum of the complex numbers a and b, z=a+b. #[doc(alias = "gsl_complex_add")] - #[deprecated(since="8.0.0", note="please use `+` instead")] + #[deprecated(since = "8.0.0", note = "please use `+` instead")] fn add(&self, other: &Complex) -> Complex; /// This function returns the difference of the complex numbers a /// and b, z=a-b. #[doc(alias = "gsl_complex_sub")] - #[deprecated(since="8.0.0", note="please use `-` instead")] + #[deprecated(since = "8.0.0", note = "please use `-` instead")] fn sub(&self, other: &Complex) -> Complex; /// This function returns the product of the complex numbers a and b, z=ab. #[doc(alias = "gsl_complex_mul")] - #[deprecated(since="8.0.0", note="please use `*` instead")] + #[deprecated(since = "8.0.0", note = "please use `*` instead")] fn mul(&self, other: &Complex) -> Complex; /// This function returns the quotient of the complex numbers a /// and b, z=a/b. #[doc(alias = "gsl_complex_div")] - #[deprecated(since="8.0.0", note="please use `/` of `fdiv` instead")] + #[deprecated(since = "8.0.0", note = "please use `/` of `fdiv` instead")] fn div(&self, other: &Complex) -> Complex; /// This function returns the sum of the complex number a and the /// real number x, z = a + x. #[doc(alias = "gsl_complex_add_real")] - #[deprecated(since="8.0.0", note="please use `+` instead")] + #[deprecated(since = "8.0.0", note = "please use `+` instead")] fn add_real(&self, x: T) -> Complex; /// This function returns the difference of the complex number a /// and the real number x, z=a-x. #[doc(alias = "gsl_complex_sub_real")] - #[deprecated(since="8.0.0", note="please use `-` instead")] + #[deprecated(since = "8.0.0", note = "please use `-` instead")] fn sub_real(&self, x: T) -> Complex; /// This function returns the product of the complex number a and /// the real number x, z=ax. #[doc(alias = "gsl_complex_mul_real")] - #[deprecated(since="8.0.0", note="please use `*` instead")] + #[deprecated(since = "8.0.0", note = "please use `*` instead")] fn mul_real(&self, x: T) -> Complex; /// This function returns the quotient of the complex number a and /// the real number x, z=a/x. #[doc(alias = "gsl_complex_div_real")] - #[deprecated(since="8.0.0", note="please use `/` instead")] + #[deprecated(since = "8.0.0", note = "please use `/` instead")] fn div_real(&self, x: T) -> Complex; /// This function returns the sum of the complex number a and the /// imaginary number iy, z=a+iy. #[doc(alias = "gsl_complex_add_imag")] - #[deprecated(since="8.0.0", note="please use `self + x * Complex::I` instead")] + #[deprecated(since = "8.0.0", note = "please use `self + x * Complex::I` instead")] fn add_imag(&self, x: T) -> Complex; /// This function returns the difference of the complex number a /// and the imaginary number iy, z=a-iy. #[doc(alias = "gsl_complex_sub_imag")] - #[deprecated(since="8.0.0", note="please use `self - x * Complex::I` instead")] + #[deprecated(since = "8.0.0", note = "please use `self - x * Complex::I` instead")] fn sub_imag(&self, x: T) -> Complex; /// This function returns the product of the complex number a and /// the imaginary number iy, z=a*(iy). #[doc(alias = "gsl_complex_mul_imag")] - #[deprecated(since="8.0.0", note="please use `self * x * Complex::I` instead")] + #[deprecated(since = "8.0.0", note = "please use `self * x * Complex::I` instead")] fn mul_imag(&self, x: T) -> Complex; /// This function returns the quotient of the complex number a and /// the imaginary number iy, z=a/(iy). #[doc(alias = "gsl_complex_div_imag")] - #[deprecated(since="8.0.0", note="please use `self / (x * Complex::I)` instead")] + #[deprecated(since = "8.0.0", note = "please use `self / (x * Complex::I)` instead")] fn div_imag(&self, x: T) -> Complex; /// This function returns the complex conjugate of the complex /// number z, z^* = x - i y. #[doc(alias = "gsl_complex_conjugate")] - #[deprecated(since="8.0.0", note="please use `.conj()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.conj()` instead")] fn conjugate(&self) -> Complex; /// This function returns the inverse, or reciprocal, of the /// complex number z, 1/z = (x - i y)/ (x^2 + y^2). #[doc(alias = "gsl_complex_inverse")] - #[deprecated(since="8.0.0", note="please use `.inv()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.inv()` instead")] fn inverse(&self) -> Complex; /// This function returns the negative of the complex number z, -z /// = (-x) + i(-y). #[doc(alias = "gsl_complex_negative")] - #[deprecated(since="8.0.0", note="please use the unary `-` instead")] + #[deprecated(since = "8.0.0", note = "please use the unary `-` instead")] fn negative(&self) -> Complex; /// This function returns the complex square root of the real @@ -160,20 +160,20 @@ pub trait ComplexOps { /// complex power a z^a. This is computed as \exp(\log(z)*a) /// using complex logarithms and complex exponentials. #[doc(alias = "gsl_complex_pow")] - #[deprecated(since="8.0.0", note="please use the unary `-` instead")] + #[deprecated(since = "8.0.0", note = "please use the unary `-` instead")] fn pow(&self, other: &Complex) -> Complex; /// This function returns the complex number z raised to the real /// power x, z^x. #[doc(alias = "gsl_complex_pow_real")] - #[deprecated(since="8.0.0", note="please use `.powf(x)` instead")] + #[deprecated(since = "8.0.0", note = "please use `.powf(x)` instead")] fn pow_real(&self, x: T) -> Complex; /// This function returns the complex base-b logarithm of the /// complex number z, \log_b(z). This quantity is computed as the /// ratio \log(z)/\log(b). #[doc(alias = "gsl_complex_log_b")] - #[deprecated(since="8.0.0", note="please use `.log(base)` instead")] + #[deprecated(since = "8.0.0", note = "please use `.log(base)` instead")] fn log_b(&self, base: &Complex) -> Complex; /// This function returns the complex secant of the complex number @@ -195,7 +195,7 @@ pub trait ComplexOps { /// number z, \arcsin(z). The branch cuts are on the real axis, /// less than -1 and greater than 1. #[doc(alias = "gsl_complex_arcsin")] - #[deprecated(since="8.0.0", note="please use `.asin()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.asin()` instead")] fn arcsin(&self) -> Complex; /// This function returns the complex arcsine of the real number @@ -214,7 +214,7 @@ pub trait ComplexOps { /// number z, \arccos(z). The branch cuts are on the real axis, /// less than -1 and greater than 1. #[doc(alias = "gsl_complex_arccos")] - #[deprecated(since="8.0.0", note="please use `.acos()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.acos()` instead")] fn arccos(&self) -> Complex; /// This function returns the complex arccosine of the real number @@ -232,7 +232,7 @@ pub trait ComplexOps { /// number z, \arctan(z). The branch cuts are on the imaginary /// axis, below -i and above i. #[doc(alias = "gsl_complex_arctan")] - #[deprecated(since="8.0.0", note="please use `.atan()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.atan()` instead")] fn arctan(&self) -> Complex; /// This function returns the complex arcsecant of the complex @@ -279,7 +279,7 @@ pub trait ComplexOps { /// complex number z, \arcsinh(z). The branch cuts are on the /// imaginary axis, below -i and above i. #[doc(alias = "gsl_complex_arcsinh")] - #[deprecated(since="8.0.0", note="please use `.asinh()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.asinh()` instead")] fn arcsinh(&self) -> Complex; /// This function returns the complex hyperbolic arccosine of the @@ -288,7 +288,7 @@ pub trait ComplexOps { /// square root in formula 4.6.21 of Abramowitz & Stegun giving /// \arccosh(z)=\log(z-\sqrt{z^2-1}). #[doc(alias = "gsl_complex_arccosh")] - #[deprecated(since="8.0.0", note="please use `.acosh()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.acosh()` instead")] fn arccosh(&self) -> Complex; /// This function returns the complex hyperbolic arccosine of the @@ -301,7 +301,7 @@ pub trait ComplexOps { /// /// The branch cuts are on the real axis, less than -1 and greater than 1. #[doc(alias = "gsl_complex_arctanh")] - #[deprecated(since="8.0.0", note="please use `.atanh()` instead")] + #[deprecated(since = "8.0.0", note = "please use `.atanh()` instead")] fn arctanh(&self) -> Complex; /// This function returns the complex hyperbolic arctangent of the @@ -324,10 +324,10 @@ pub trait ComplexOps { #[doc(alias = "gsl_complex_arccoth")] fn arccoth(&self) -> Complex; - #[deprecated(since="8.0.0", note="please use `.re` instead")] + #[deprecated(since = "8.0.0", note = "please use `.re` instead")] fn real(&self) -> T; - #[deprecated(since="8.0.0", note="please use `.im` instead")] + #[deprecated(since = "8.0.0", note = "please use `.im` instead")] fn imaginary(&self) -> T; } @@ -533,12 +533,13 @@ impl ComplexOps for Complex { } } - // The GLS Complex module does not support `f32` operations. Thus we // convert back and forth to `f64`. impl ToC for Complex { fn unwrap(self) -> sys::gsl_complex { - sys::gsl_complex { dat: [self.re as f64, self.im as f64]} + sys::gsl_complex { + dat: [self.re as f64, self.im as f64], + } } } @@ -559,7 +560,10 @@ impl ToC for Complex { impl FromC> for sys::gsl_complex { fn wrap(self) -> Complex { let [re, im] = self.dat; - Complex { re: re as f32, im: im as f32 } + Complex { + re: re as f32, + im: im as f32, + } } } @@ -580,7 +584,7 @@ impl ComplexOps for Complex { unsafe { sys::gsl_complex_abs2(self.unwrap()) as f32 } } - fn logabs(&self) -> f32 { + fn logabs(&self) -> f32 { unsafe { sys::gsl_complex_logabs(self.unwrap()) as f32 } } @@ -765,7 +769,6 @@ impl ComplexOps for Complex { } } - #[cfg(test)] mod tests { // All these tests have been tested against the following C code: diff --git a/src/types/eigen_symmetric_workspace.rs b/src/types/eigen_symmetric_workspace.rs index 64d2aa24..23d846a9 100644 --- a/src/types/eigen_symmetric_workspace.rs +++ b/src/types/eigen_symmetric_workspace.rs @@ -890,10 +890,10 @@ fn eigen_symmetric_vworkspace() { // ``` #[test] fn eigen_hermitian_workspace() { + use crate::complex::ComplexOps; + use num_complex::Complex; use MatrixComplexF64; use VectorF64; - use num_complex::Complex; - use crate::complex::ComplexOps; let mut e = EigenHermitianWorkspace::new(3).unwrap(); let mut m = MatrixComplexF64::new(2, 2).unwrap(); @@ -945,8 +945,8 @@ fn eigen_hermitian_workspace() { // ``` #[test] fn eigen_hermitian_vworkspace() { - use num_complex::Complex; use crate::complex::ComplexOps; + use num_complex::Complex; let mut e = EigenHermitianVWorkspace::new(3).unwrap(); let mut m = MatrixComplexF64::new(2, 2).unwrap(); diff --git a/src/types/fast_fourier_transforms.rs b/src/types/fast_fourier_transforms.rs index 8593d360..093d1482 100644 --- a/src/types/fast_fourier_transforms.rs +++ b/src/types/fast_fourier_transforms.rs @@ -3,10 +3,7 @@ // use crate::ffi::FFI; -use crate::{ - vector::VectorMut, - Value -}; +use crate::{vector::VectorMut, Value}; use paste::paste; macro_rules! gsl_fft_wavetable { diff --git a/src/types/matrix_complex.rs b/src/types/matrix_complex.rs index 5a0c0b4f..1fdf0868 100644 --- a/src/types/matrix_complex.rs +++ b/src/types/matrix_complex.rs @@ -7,8 +7,8 @@ use crate::{ complex::{FromC, ToC}, Value, }; -use paste::paste; use num_complex::Complex; +use paste::paste; use std::fmt::{self, Debug, Formatter}; macro_rules! gsl_matrix_complex { diff --git a/src/types/vector_complex.rs b/src/types/vector_complex.rs index 6c789058..532b8398 100644 --- a/src/types/vector_complex.rs +++ b/src/types/vector_complex.rs @@ -7,16 +7,16 @@ use crate::{ complex::{FromC, ToC}, Value, }; +use num_complex::Complex; use paste::paste; use std::{ fmt::{self, Debug, Formatter}, marker::PhantomData, }; -use num_complex::Complex; macro_rules! gsl_vec_complex { ($rust_name:ident, $name:ident, $complex:ident, $rust_ty:ident) => { - paste! { + paste! { pub struct $rust_name { vec: *mut sys::$name, can_free: bool, From 45b8a3307ae8582666c7e1e295dd1598f5d19507 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 00:45:06 +0100 Subject: [PATCH 22/28] Fix errors reported by the custom checker --- src/types/complex.rs | 96 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/src/types/complex.rs b/src/types/complex.rs index 06cded69..c99a6c8f 100644 --- a/src/types/complex.rs +++ b/src/types/complex.rs @@ -332,194 +332,242 @@ pub trait ComplexOps { } impl ComplexOps for Complex { + #[doc(alias = "gsl_complex_rect")] fn rect(x: f64, y: f64) -> Complex { unsafe { sys::gsl_complex_rect(x, y).wrap() } } + #[doc(alias = "gsl_complex_polar")] fn polar(r: f64, theta: f64) -> Complex { unsafe { sys::gsl_complex_polar(r, theta).wrap() } } + #[doc(alias = "gsl_complex_abs")] fn abs(&self) -> f64 { unsafe { sys::gsl_complex_abs(self.unwrap()) } } + #[doc(alias = "gsl_complex_abs2")] fn abs2(&self) -> f64 { unsafe { sys::gsl_complex_abs2(self.unwrap()) } } + #[doc(alias = "gsl_complex_logabs")] fn logabs(&self) -> f64 { unsafe { sys::gsl_complex_logabs(self.unwrap()) } } + #[doc(alias = "gsl_complex_add")] fn add(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_add(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sub")] fn sub(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_sub(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_mul")] fn mul(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_mul(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_div")] fn div(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_div(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_add_real")] fn add_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_add_real(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_sub_real")] fn sub_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_sub_real(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_mul_real")] fn mul_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_mul_real(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_div_real")] fn div_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_div_real(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_add_imag")] fn add_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_add_imag(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_sub_imag")] fn sub_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_sub_imag(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_mul_imag")] fn mul_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_mul_imag(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_div_imag")] fn div_imag(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_div_imag(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_conjugate")] fn conjugate(&self) -> Complex { unsafe { sys::gsl_complex_conjugate(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_inverse")] fn inverse(&self) -> Complex { unsafe { sys::gsl_complex_inverse(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_negative")] fn negative(&self) -> Complex { unsafe { sys::gsl_complex_negative(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sqrt_real")] fn sqrt_real(x: f64) -> Complex { unsafe { sys::gsl_complex_sqrt_real(x).wrap() } } + #[doc(alias = "gsl_complex_pow")] fn pow(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_pow(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_pow_real")] fn pow_real(&self, x: f64) -> Complex { unsafe { sys::gsl_complex_pow_real(self.unwrap(), x).wrap() } } + #[doc(alias = "gsl_complex_log_b")] fn log_b(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_log_b(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sec")] fn sec(&self) -> Complex { unsafe { sys::gsl_complex_sec(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_csc")] fn csc(&self) -> Complex { unsafe { sys::gsl_complex_csc(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_cot")] fn cot(&self) -> Complex { unsafe { sys::gsl_complex_cot(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsin")] fn arcsin(&self) -> Complex { unsafe { sys::gsl_complex_arcsin(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsin_real")] fn arcsin_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arcsin_real(z).wrap() } } + #[doc(alias = "gsl_complex_arccos")] fn arccos(&self) -> Complex { unsafe { sys::gsl_complex_arccos(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccos_real")] fn arccos_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccos_real(z).wrap() } } + #[doc(alias = "gsl_complex_arctan")] fn arctan(&self) -> Complex { unsafe { sys::gsl_complex_arctan(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsec")] fn arcsec(&self) -> Complex { unsafe { sys::gsl_complex_arcsec(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsec_real")] fn arcsec_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arcsec_real(z).wrap() } } + #[doc(alias = "gsl_complex_arccsc")] fn arccsc(&self) -> Complex { unsafe { sys::gsl_complex_arccsc(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccsc_real")] fn arccsc_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccsc_real(z).wrap() } } + #[doc(alias = "gsl_complex_arccot")] fn arccot(&self) -> Complex { unsafe { sys::gsl_complex_arccot(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sech")] fn sech(&self) -> Complex { unsafe { sys::gsl_complex_sech(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_csch")] fn csch(&self) -> Complex { unsafe { sys::gsl_complex_csch(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_coth")] fn coth(&self) -> Complex { unsafe { sys::gsl_complex_coth(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsinh")] fn arcsinh(&self) -> Complex { unsafe { sys::gsl_complex_arcsinh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccosh")] fn arccosh(&self) -> Complex { unsafe { sys::gsl_complex_arccosh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccosh_real")] fn arccosh_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arccosh_real(z).wrap() } } + #[doc(alias = "gsl_complex_arctanh")] fn arctanh(&self) -> Complex { unsafe { sys::gsl_complex_arctanh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arctanh_real")] fn arctanh_real(z: f64) -> Complex { unsafe { sys::gsl_complex_arctanh_real(z).wrap() } } + #[doc(alias = "gsl_complex_arcsech")] fn arcsech(&self) -> Complex { unsafe { sys::gsl_complex_arcsech(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccsch")] fn arccsch(&self) -> Complex { unsafe { sys::gsl_complex_arccsch(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccoth")] fn arccoth(&self) -> Complex { unsafe { sys::gsl_complex_arccoth(self.unwrap()).wrap() } } @@ -568,194 +616,242 @@ impl FromC> for sys::gsl_complex { } impl ComplexOps for Complex { + #[doc(alias = "gsl_complex_rect")] fn rect(x: f32, y: f32) -> Complex { unsafe { sys::gsl_complex_rect(x as f64, y as f64).wrap() } } + #[doc(alias = "gsl_complex_polar")] fn polar(r: f32, theta: f32) -> Complex { unsafe { sys::gsl_complex_polar(r as f64, theta as f64).wrap() } } + #[doc(alias = "gsl_complex_abs")] fn abs(&self) -> f32 { unsafe { sys::gsl_complex_abs(self.unwrap()) as f32 } } + #[doc(alias = "gsl_complex_abs2")] fn abs2(&self) -> f32 { unsafe { sys::gsl_complex_abs2(self.unwrap()) as f32 } } + #[doc(alias = "gsl_complex_logabs")] fn logabs(&self) -> f32 { unsafe { sys::gsl_complex_logabs(self.unwrap()) as f32 } } + #[doc(alias = "gsl_complex_add")] fn add(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_add(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sub")] fn sub(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_sub(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_mul")] fn mul(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_mul(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_div")] fn div(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_div(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_add_real")] fn add_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_add_real(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_sub_real")] fn sub_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_sub_real(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_mul_real")] fn mul_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_mul_real(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_div_real")] fn div_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_div_real(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_add_imag")] fn add_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_add_imag(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_sub_imag")] fn sub_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_sub_imag(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_mul_imag")] fn mul_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_mul_imag(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_div_imag")] fn div_imag(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_div_imag(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_conjugate")] fn conjugate(&self) -> Complex { unsafe { sys::gsl_complex_conjugate(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_inverse")] fn inverse(&self) -> Complex { unsafe { sys::gsl_complex_inverse(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_negative")] fn negative(&self) -> Complex { unsafe { sys::gsl_complex_negative(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sqrt_real")] fn sqrt_real(x: f32) -> Complex { unsafe { sys::gsl_complex_sqrt_real(x as f64).wrap() } } + #[doc(alias = "gsl_complex_pow")] fn pow(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_pow(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_pow_real")] fn pow_real(&self, x: f32) -> Complex { unsafe { sys::gsl_complex_pow_real(self.unwrap(), x as f64).wrap() } } + #[doc(alias = "gsl_complex_log_b")] fn log_b(&self, other: &Complex) -> Complex { unsafe { sys::gsl_complex_log_b(self.unwrap(), other.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sec")] fn sec(&self) -> Complex { unsafe { sys::gsl_complex_sec(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_csc")] fn csc(&self) -> Complex { unsafe { sys::gsl_complex_csc(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_cot")] fn cot(&self) -> Complex { unsafe { sys::gsl_complex_cot(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsin")] fn arcsin(&self) -> Complex { unsafe { sys::gsl_complex_arcsin(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsin_real")] fn arcsin_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arcsin_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arccos")] fn arccos(&self) -> Complex { unsafe { sys::gsl_complex_arccos(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccos_real")] fn arccos_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccos_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arctan")] fn arctan(&self) -> Complex { unsafe { sys::gsl_complex_arctan(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsec")] fn arcsec(&self) -> Complex { unsafe { sys::gsl_complex_arcsec(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsec_real")] fn arcsec_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arcsec_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arccsc")] fn arccsc(&self) -> Complex { unsafe { sys::gsl_complex_arccsc(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccsc_real")] fn arccsc_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccsc_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arccot")] fn arccot(&self) -> Complex { unsafe { sys::gsl_complex_arccot(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_sech")] fn sech(&self) -> Complex { unsafe { sys::gsl_complex_sech(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_csch")] fn csch(&self) -> Complex { unsafe { sys::gsl_complex_csch(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_coth")] fn coth(&self) -> Complex { unsafe { sys::gsl_complex_coth(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arcsinh")] fn arcsinh(&self) -> Complex { unsafe { sys::gsl_complex_arcsinh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccosh")] fn arccosh(&self) -> Complex { unsafe { sys::gsl_complex_arccosh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccosh_real")] fn arccosh_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arccosh_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arctanh")] fn arctanh(&self) -> Complex { unsafe { sys::gsl_complex_arctanh(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arctanh_real")] fn arctanh_real(z: f32) -> Complex { unsafe { sys::gsl_complex_arctanh_real(z as f64).wrap() } } + #[doc(alias = "gsl_complex_arcsech")] fn arcsech(&self) -> Complex { unsafe { sys::gsl_complex_arcsech(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccsch")] fn arccsch(&self) -> Complex { unsafe { sys::gsl_complex_arccsch(self.unwrap()).wrap() } } + #[doc(alias = "gsl_complex_arccoth")] fn arccoth(&self) -> Complex { unsafe { sys::gsl_complex_arccoth(self.unwrap()).wrap() } } From 9c359281c104aaa8419a143bdcb2842d70a395cd Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 00:52:20 +0100 Subject: [PATCH 23/28] Slightly relax a test on "drotg" --- src/blas.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/blas.rs b/src/blas.rs index cd5da230..3027bcff 100644 --- a/src/blas.rs +++ b/src/blas.rs @@ -463,9 +463,9 @@ pub mod level1 { #[test] fn test_drotg() { let (c, s, r) = drotg(3., 4.).unwrap(); - assert_eq!(c, 0.6); - assert_eq!(s, 0.8); - assert_eq!(r, 5.); + assert!((c - 0.6).abs() < 5e-16, "|{c} - 0.6| >= 5e-16"); + assert!((s - 0.8).abs() < 5e-16, "|{s} - 0.8| >= 5e-16"); + assert!((r - 5.).abs() < 1e-15, "|{r} - 5.| >= 1e-15"); } } } From eb99f9353ee5d1fec8bc96cecf4c7a9a5d73310a Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 01:04:04 +0100 Subject: [PATCH 24/28] Fix running examples --- .github/workflows/CI.yml | 4 ++-- examples/Cargo.toml | 23 +++++++++++++++-------- examples/README.md | 2 +- examples/run-examples.py | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a50f6338..17b50d16 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,7 +35,7 @@ jobs: run: cargo test --features v2_7 - name: check examples working-directory: examples - run: cargo check --features GSL/v2_7 + run: cargo check --features v2_7 - name: run examples working-directory: examples run: python3 run-examples.py @@ -65,7 +65,7 @@ jobs: - run: cargo check --features v2_7 - name: check examples working-directory: examples - run: cargo check --features GSL/v2_7 + run: cargo check --features v2_7 fmt: name: rust fmt diff --git a/examples/Cargo.toml b/examples/Cargo.toml index ada63a10..2e6ad1ec 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -2,6 +2,13 @@ name = "gsl-examples" version = "0.4.6" authors = ["Guillaume Gomez "] +edition = "2021" + +[features] +v2_1 = ["GSL/v2_1"] +v2_2 = ["GSL/v2_2", "v2_1"] +v2_5 = ["GSL/v2_5", "v2_2"] +v2_7 = ["GSL/v2_7", "v2_5"] [dependencies.GSL] path = "../" @@ -56,22 +63,22 @@ path = "./fftmr.rs" [[bin]] name = "filt_edge" path = "./filt_edge.rs" -required-features = ["GSL/v2_5"] +required-features = ["v2_5"] [[bin]] name = "fitreg" path = "./fitreg.rs" -required-features = ["GSL/v2_1"] +required-features = ["v2_1"] [[bin]] name = "fitreg2" path = "./fitreg2.rs" -required-features = ["GSL/v2_1"] +required-features = ["v2_1"] [[bin]] name = "fitting" path = "./fitting.rs" -required-features = ["GSL/v2_5"] +required-features = ["v2_5"] [[bin]] name = "fitting3" @@ -80,12 +87,12 @@ path = "./fitting3.rs" [[bin]] name = "gaussfilt" path = "./gaussfilt.rs" -required-features = ["GSL/v2_5"] +required-features = ["v2_5"] [[bin]] name = "gaussfilt2" path = "./gaussfilt2.rs" -required-features = ["GSL/v2_5"] +required-features = ["v2_5"] [[bin]] name = "histogram2d" @@ -94,7 +101,7 @@ path = "./histogram2d.rs" [[bin]] name = "impulse" path = "./impulse.rs" -required-features = ["GSL/v2_5"] +required-features = ["v2_5"] [[bin]] name = "integration" @@ -111,7 +118,7 @@ path = "./intro.rs" [[bin]] name = "largefit" path = "./largefit.rs" -required-features = ["GSL/v2_2"] +required-features = ["v2_2"] [[bin]] name = "rng" diff --git a/examples/README.md b/examples/README.md index be733d8c..9b195ac9 100644 --- a/examples/README.md +++ b/examples/README.md @@ -12,7 +12,7 @@ Some examples might require a higher GSL version. `rgsl` supports versions throu So for example: ```bash -$ cargo run --bin largefit --features GSL/v2_2 +$ cargo run --bin largefit --features v2_2 ``` # Original examples diff --git a/examples/run-examples.py b/examples/run-examples.py index 1df95914..1a996ec0 100644 --- a/examples/run-examples.py +++ b/examples/run-examples.py @@ -5,7 +5,7 @@ def run_example(example_name): print("====> Running {}".format(example_name)) - command = ["cargo", "run", "--bin", example_name, "--features", "GSL/v2_7"] + command = ["cargo", "run", "--bin", example_name, "--features", "v2_7"] child = subprocess.Popen(command) child.communicate() if child.returncode != 0: From 506d95ac2cb849172e7680dfb5c77cfafd89e831 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 11:29:30 +0100 Subject: [PATCH 25/28] fast_fourier_transforms: use complex types --- Cargo.toml | 2 +- examples/Cargo.toml | 1 + examples/fftmr.rs | 25 +++++---------- src/types/fast_fourier_transforms.rs | 46 ++++++++++------------------ src/types/vector.rs | 24 +++++++++++++++ 5 files changed, 50 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fdfef2f4..fc2b406b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ edition = "2021" [dependencies] sys = { path = "gsl-sys", package = "GSL-sys", version = "3.0.0" } paste = "1.0" -num-complex = { version = "0.4.5", optional = true } +num-complex = { version = "0.4.6", optional = true } [features] default = ["complex"] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 2e6ad1ec..dccff148 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -15,6 +15,7 @@ path = "../" [dependencies] libc = "~0.2" +num-complex = "0.4.6" [[bin]] name = "blas" diff --git a/examples/fftmr.rs b/examples/fftmr.rs index 460473e6..68d0c821 100644 --- a/examples/fftmr.rs +++ b/examples/fftmr.rs @@ -4,36 +4,25 @@ extern crate rgsl; +use num_complex::c64; use rgsl::{FftComplexF64WaveTable, FftComplexF64Workspace}; -// FIXME: Make the interface use complex numbers. -macro_rules! real { - ($z:ident, $i:expr) => { - $z[2 * ($i)] - }; -} -macro_rules! imag { - ($z:ident, $i:expr) => { - $z[2 * ($i) + 1] - }; -} - const N: usize = 128; fn main() { - let data = &mut [0.; 2 * N]; + let data = &mut [c64(0., 0.); N]; let wavetable = FftComplexF64WaveTable::new(N).expect("FftComplexF64WaveTable::new failed"); let mut workspace = FftComplexF64Workspace::new(N).expect("FftComplexF64Workspace::new failed"); - data[0] = 1.; + data[0].re = 1.; for i in 1..=10 { - real!(data, i) = 1.; - real!(data, N - i) = 1.; + data[i].re = 1.; + data[N - i].re = 1.; } for i in 0..N { - println!("{}: {} {}", i, real!(data, i), imag!(data, i)); + println!("{}: {}", i, data[i]); } println!(); @@ -44,6 +33,6 @@ fn main() { workspace.forward(data, &wavetable).unwrap(); for i in 0..N { - println!("{}: {} {}", i, real!(data, i), imag!(data, i)); + println!("{}: {}", i, data[i]); } } diff --git a/src/types/fast_fourier_transforms.rs b/src/types/fast_fourier_transforms.rs index 093d1482..0e8b1d92 100644 --- a/src/types/fast_fourier_transforms.rs +++ b/src/types/fast_fourier_transforms.rs @@ -3,7 +3,11 @@ // use crate::ffi::FFI; -use crate::{vector::VectorMut, Value}; +use crate::{ + vector::{ComplexSlice, VectorMut}, + Value, +}; +use num_complex::Complex; use paste::paste; macro_rules! gsl_fft_wavetable { @@ -74,20 +78,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _forward)] - pub fn forward + ?Sized>( + pub fn forward> + ?Sized>( &mut self, data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { - if V::len(data) % 2 == 1 { - panic!("{}: the length of the data must be even", - stringify!($complex_rust_name::forward)); - } let ret = unsafe { sys::[<$name $($extra)? _forward>]( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), - V::len(data) / 2, // FIXME: use complex vectors? + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -96,21 +96,17 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _transform)] - pub fn transform + ?Sized>( + pub fn transform> + ?Sized>( &mut self, data: &mut V, wavetable: &$rust_name, sign: crate::FftDirection, ) -> Result<(), Value> { - if V::len(data) % 2 == 1 { - panic!("{}: the length of the data must be even", - stringify!($complex_rust_name::transform)); - } let ret = unsafe { sys::[<$name $($extra)? _transform>]( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), - V::len(data) / 2, // FIXME: use complex vectors? + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), sign.into(), @@ -120,20 +116,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _backward)] - pub fn backward + ?Sized>( + pub fn backward> + ?Sized>( &mut self, data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { - if V::len(data) % 2 == 1 { - panic!("{}: the length of the data must be even", - stringify!($complex_rust_name::backward)); - } let ret = unsafe { sys::[<$name $($extra)? _backward>]( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), - V::len(data) / 2, // FIXME: use complex vectors? + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) @@ -142,20 +134,16 @@ impl $complex_rust_name { } #[doc(alias = $name $($extra)? _inverse)] - pub fn inverse + ?Sized>( + pub fn inverse> + ?Sized>( &mut self, data: &mut V, wavetable: &$rust_name, ) -> Result<(), Value> { - if V::len(data) % 2 == 1 { - panic!("{}: the length of the data must be even", - stringify!($complex_rust_name::inverse)); - } let ret = unsafe { sys::[<$name $($extra)? _inverse>]( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), - V::len(data) / 2, // FIXME: use complex vectors? + V::len(data), wavetable.unwrap_shared(), self.unwrap_unique(), ) diff --git a/src/types/vector.rs b/src/types/vector.rs index 120e2aeb..e6b06544 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -818,3 +818,27 @@ impl_AsRef!(f64); impl_AsRef!(Complex); #[cfg(feature = "complex")] impl_AsRef!(Complex); + +// Helper trait to convert Complex slices. +pub(crate) trait ComplexSlice { + // fn as_ptr_fXX(&self) -> *const T; + fn as_mut_ptr_fXX(&mut self) -> *mut T; +} + +macro_rules! impl_ComplexSlice { + ($ty: ty) => { + impl ComplexSlice<$ty> for [Complex<$ty>] { + // fn as_ptr_fXX(&self) -> *const $ty { + // // Complex layout is two consecutive f64 values. + // self.as_ptr() as *const $ty + // } + + fn as_mut_ptr_fXX(&mut self) -> *mut $ty { + self.as_mut_ptr() as *mut $ty + } + } + }; +} + +impl_ComplexSlice!(f64); +impl_ComplexSlice!(f32); From 67a208d7c967ecf858895c230f3ffc813b14bc73 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 14:37:10 +0100 Subject: [PATCH 26/28] fft: Use complex types --- examples/fft.rs | 33 ++++++++------------------------- src/fft.rs | 47 ++++++++++++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/examples/fft.rs b/examples/fft.rs index 909c99d0..fc92f26e 100644 --- a/examples/fft.rs +++ b/examples/fft.rs @@ -2,47 +2,30 @@ // A rust binding for the GSL library by Guillaume Gomez (guillaume1.gomez@gmail.com) // -extern crate rgsl; - +use num_complex::c64; use rgsl::fft; -macro_rules! real { - ($z:ident, $i:expr) => { - $z[2 * ($i)] - }; -} -macro_rules! imag { - ($z:ident, $i:expr) => { - $z[2 * ($i) + 1] - }; -} - const N: usize = 128; fn main() { - let data = &mut [0.; 2 * N]; + let data = &mut [c64(0., 0.); N]; - real!(data, 0) = 1.; + data[0].re = 1.; for i in 1..=10 { - real!(data, i) = 1.; - real!(data, N - i) = 1.; + data[i].re = 1.; + data[N - i].re = 1.; } for i in 0..N { - println!("{} {} {}", i, real!(data, i), imag!(data, i)); + println!("{} {}", i, data[i]); } println!(); println!(); - fft::radix2::forward(data, 1, N).unwrap(); + fft::radix2::forward(data).unwrap(); for i in 0..N { - println!( - "{} {} {}", - i, - real!(data, i) / 128f64.sqrt(), - imag!(data, i) / 128f64.sqrt() - ); + println!("{} {}", i, data[i] / 128f64.sqrt()); } } diff --git a/src/fft.rs b/src/fft.rs index 2cb820db..ea58250a 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -165,22 +165,35 @@ is desirable for better locality of memory accesses). /// /// The functions return a value of crate::Value::Success if no errors were detected, or Value::Dom if the length n is not a power of two. pub mod radix2 { - use crate::{vector::VectorMut, Value}; + use crate::{ + vector::{ComplexSlice, VectorMut}, + Value, + }; + use num_complex::Complex; #[doc(alias = "gsl_fft_complex_radix2_forward")] - pub fn forward(data: &mut [f64], stride: usize, n: usize) -> Result<(), Value> { - let ret = unsafe { sys::gsl_fft_complex_radix2_forward(data.as_mut_ptr(), stride, n) }; + pub fn forward(data: &mut V) -> Result<(), Value> + where + V: VectorMut> + ?Sized, + { + let ret = unsafe { + sys::gsl_fft_complex_radix2_forward( + V::as_mut_slice(data).as_mut_ptr_fXX(), + V::stride(data), + V::len(data), + ) + }; result_handler!(ret, ()) } #[doc(alias = "gsl_fft_complex_radix2_transform")] - pub fn transform + ?Sized>( + pub fn transform> + ?Sized>( data: &mut V, sign: crate::FftDirection, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_fft_complex_radix2_transform( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), sign.into(), @@ -192,11 +205,11 @@ pub mod radix2 { #[doc(alias = "gsl_fft_complex_radix2_backward")] pub fn backward(data: &mut V) -> Result<(), Value> where - V: VectorMut + ?Sized, + V: VectorMut> + ?Sized, { let ret = unsafe { sys::gsl_fft_complex_radix2_backward( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), ) @@ -207,11 +220,11 @@ pub mod radix2 { #[doc(alias = "gsl_fft_complex_radix2_inverse")] pub fn inverse(data: &mut V) -> Result<(), Value> where - V: VectorMut + ?Sized, + V: VectorMut> + ?Sized, { let ret = unsafe { sys::gsl_fft_complex_radix2_inverse( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), ) @@ -223,11 +236,11 @@ pub mod radix2 { #[doc(alias = "gsl_fft_complex_radix2_dif_forward")] pub fn dif_forward(data: &mut V) -> Result<(), Value> where - V: VectorMut + ?Sized, + V: VectorMut> + ?Sized, { let ret = unsafe { sys::gsl_fft_complex_radix2_dif_forward( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), ) @@ -237,13 +250,13 @@ pub mod radix2 { /// This is decimation-in-frequency version of the radix-2 FFT function. #[doc(alias = "gsl_fft_complex_radix2_dif_transform")] - pub fn dif_transform + ?Sized>( + pub fn dif_transform> + ?Sized>( data: &mut V, sign: crate::FftDirection, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_fft_complex_radix2_dif_transform( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), sign.into(), @@ -256,11 +269,11 @@ pub mod radix2 { #[doc(alias = "gsl_fft_complex_radix2_dif_backward")] pub fn dif_backward(data: &mut V) -> Result<(), Value> where - V: VectorMut + ?Sized, + V: VectorMut> + ?Sized, { let ret = unsafe { sys::gsl_fft_complex_radix2_dif_backward( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), ) @@ -272,11 +285,11 @@ pub mod radix2 { #[doc(alias = "gsl_fft_complex_radix2_dif_inverse")] pub fn dif_inverse(data: &mut V) -> Result<(), Value> where - V: VectorMut + ?Sized, + V: VectorMut> + ?Sized, { let ret = unsafe { sys::gsl_fft_complex_radix2_dif_inverse( - V::as_mut_slice(data).as_mut_ptr(), + V::as_mut_slice(data).as_mut_ptr_fXX(), V::stride(data), V::len(data), ) From 0d11f49c673112000a4179b2595d02e583c2a6f7 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 14:48:53 +0100 Subject: [PATCH 27/28] Do not consider examples as a separate crate The benefit is that all examples are compiled with "cargo test" which avoids bit rotting (without the need for CI). --- .github/workflows/CI.yml | 2 - Cargo.toml | 33 +++++ examples/Cargo.toml | 134 -------------------- examples/README.md | 4 +- examples/run-examples.py => run-examples.py | 4 +- 5 files changed, 37 insertions(+), 140 deletions(-) delete mode 100644 examples/Cargo.toml rename examples/run-examples.py => run-examples.py (79%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 17b50d16..25ea4568 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -34,10 +34,8 @@ jobs: - name: run tests run: cargo test --features v2_7 - name: check examples - working-directory: examples run: cargo check --features v2_7 - name: run examples - working-directory: examples run: python3 run-examples.py build-osx: diff --git a/Cargo.toml b/Cargo.toml index fc2b406b..e2a3da27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,36 @@ rustdoc-args = ["--generate-link-to-definition"] [lib] name = "rgsl" crate-type = ["dylib", "rlib"] + +# Examples with special requirements +[[example]] +name = "filt_edge" +required-features = ["v2_5"] + +[[example]] +name = "fitreg" +required-features = ["v2_5"] + +[[example]] +name = "fitreg2" +required-features = ["v2_5"] + +[[example]] +name = "fitting" +required-features = ["v2_5"] + +[[example]] +name = "gaussfilt" +required-features = ["v2_5"] + +[[example]] +name = "gaussfilt2" +required-features = ["v2_5"] + +[[example]] +name = "impulse" +required-features = ["v2_5"] + +[[example]] +name = "largefit" +required-features = ["v2_2"] diff --git a/examples/Cargo.toml b/examples/Cargo.toml deleted file mode 100644 index dccff148..00000000 --- a/examples/Cargo.toml +++ /dev/null @@ -1,134 +0,0 @@ -[package] -name = "gsl-examples" -version = "0.4.6" -authors = ["Guillaume Gomez "] -edition = "2021" - -[features] -v2_1 = ["GSL/v2_1"] -v2_2 = ["GSL/v2_2", "v2_1"] -v2_5 = ["GSL/v2_5", "v2_2"] -v2_7 = ["GSL/v2_7", "v2_5"] - -[dependencies.GSL] -path = "../" - -[dependencies] -libc = "~0.2" -num-complex = "0.4.6" - -[[bin]] -name = "blas" -path = "./blas.rs" - -[[bin]] -name = "bspline" -path = "./bspline.rs" - -[[bin]] -name = "cblas" -path = "./cblas.rs" - -[[bin]] -name = "cdf" -path = "./cdf.rs" - -[[bin]] -name = "chebyshev_approximation" -path = "./chebyshev_approximation.rs" - -[[bin]] -name = "combination" -path = "./combination.rs" - -[[bin]] -name = "diff" -path = "./diff.rs" - -[[bin]] -name = "eigen" -path = "./eigen.rs" - -[[bin]] -name = "eigen_nonsymm" -path = "./eigen_nonsymm.rs" - -[[bin]] -name = "fft" -path = "./fft.rs" - -[[bin]] -name = "fftmr" -path = "./fftmr.rs" - -[[bin]] -name = "filt_edge" -path = "./filt_edge.rs" -required-features = ["v2_5"] - -[[bin]] -name = "fitreg" -path = "./fitreg.rs" -required-features = ["v2_1"] - -[[bin]] -name = "fitreg2" -path = "./fitreg2.rs" -required-features = ["v2_1"] - -[[bin]] -name = "fitting" -path = "./fitting.rs" -required-features = ["v2_5"] - -[[bin]] -name = "fitting3" -path = "./fitting3.rs" - -[[bin]] -name = "gaussfilt" -path = "./gaussfilt.rs" -required-features = ["v2_5"] - -[[bin]] -name = "gaussfilt2" -path = "./gaussfilt2.rs" -required-features = ["v2_5"] - -[[bin]] -name = "histogram2d" -path = "./histogram2d.rs" - -[[bin]] -name = "impulse" -path = "./impulse.rs" -required-features = ["v2_5"] - -[[bin]] -name = "integration" -path = "./integration.rs" - -[[bin]] -name = "integration2" -path = "./integration2.rs" - -[[bin]] -name = "intro" -path = "./intro.rs" - -[[bin]] -name = "largefit" -path = "./largefit.rs" -required-features = ["v2_2"] - -[[bin]] -name = "rng" -path = "./rng.rs" - -[[bin]] -name = "statistics" -path = "./statistics.rs" - -[[bin]] -name = "vectors_and_matrices" -path = "./vectors_and_matrices.rs" diff --git a/examples/README.md b/examples/README.md index 9b195ac9..85015f3e 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,7 +3,7 @@ This folder contains the `rgsl` examples. To run one, just use `cargo`: ```bash -$ cargo run --bin intro +$ cargo run --example intro ``` And that's it! @@ -12,7 +12,7 @@ Some examples might require a higher GSL version. `rgsl` supports versions throu So for example: ```bash -$ cargo run --bin largefit --features v2_2 +$ cargo run --example largefit --features v2_2 ``` # Original examples diff --git a/examples/run-examples.py b/run-examples.py similarity index 79% rename from examples/run-examples.py rename to run-examples.py index 1a996ec0..d1a2f193 100644 --- a/examples/run-examples.py +++ b/run-examples.py @@ -5,7 +5,7 @@ def run_example(example_name): print("====> Running {}".format(example_name)) - command = ["cargo", "run", "--bin", example_name, "--features", "v2_7"] + command = ["cargo", "run", "--example", example_name, "--features", "v2_7"] child = subprocess.Popen(command) child.communicate() if child.returncode != 0: @@ -16,7 +16,7 @@ def run_example(example_name): def run_examples(): ret = 0 - for example in [f for f in os.listdir('.') if os.path.isfile(f)]: + for example in [f for f in os.listdir('examples')]: if not example.endswith('.rs'): continue if not run_example(example[:-3]): From 0f026675f4e7fcd42ac7e11e547b91a117cd0036 Mon Sep 17 00:00:00 2001 From: Christophe Troestler Date: Thu, 26 Dec 2024 16:03:29 +0100 Subject: [PATCH 28/28] Use the Vector trait for statistics, permutations, and wavelets --- examples/statistics.rs | 8 +- src/statistics.rs | 645 ++++++++++++++++++++++++++++---------- src/types/permutation.rs | 26 +- src/wavelet_transforms.rs | 38 +-- 4 files changed, 529 insertions(+), 188 deletions(-) diff --git a/examples/statistics.rs b/examples/statistics.rs index f36b7cc0..9f25773b 100644 --- a/examples/statistics.rs +++ b/examples/statistics.rs @@ -7,10 +7,10 @@ extern crate rgsl; fn main() { let data: [f64; 5] = [17.2, 18.1, 16.5, 18.3, 12.6]; - let mean = rgsl::statistics::mean(&data, 1, 5); - let variance = rgsl::statistics::variance(&data, 1, 5); - let largest = rgsl::statistics::max(&data, 1, 5); - let smallest = rgsl::statistics::min(&data, 1, 5); + let mean = rgsl::statistics::mean(&data); + let variance = rgsl::statistics::variance(&data); + let largest = rgsl::statistics::max(&data); + let smallest = rgsl::statistics::min(&data); println!( "The dataset is {}, {}, {}, {}, {}", diff --git a/src/statistics.rs b/src/statistics.rs index 95521a33..b93e648a 100644 --- a/src/statistics.rs +++ b/src/statistics.rs @@ -44,6 +44,8 @@ Review of Particle Properties R.M. Barnett et al., Physical Review D54, 1 (1996) The Review of Particle Physics is available online at the website . !*/ +use crate::vector::{check_equal_len, Vector}; + /// This function returns the arithmetic mean of data, a dataset of length n with stride stride. The /// arithmetic mean, or sample mean, is denoted by \Hat\mu and defined as, /// @@ -52,8 +54,11 @@ The Review of Particle Physics is available online at the website f64 { - unsafe { sys::gsl_stats_mean(data.as_ptr(), stride, n) } +pub fn mean(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_mean(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the estimated, or sample, variance of data, a dataset of length n with @@ -69,8 +74,11 @@ pub fn mean(data: &[f64], stride: usize, n: usize) -> f64 { /// This function computes the mean via a call to gsl_stats_mean. If you have already computed the /// mean then you can pass it directly to gsl_stats_variance_m. #[doc(alias = "gsl_stats_variance")] -pub fn variance(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_variance(data.as_ptr(), stride, n) } +pub fn variance(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_variance(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the sample variance of data relative to the given value of mean. The @@ -78,22 +86,45 @@ pub fn variance(data: &[f64], stride: usize, n: usize) -> f64 { /// /// \Hat\sigma^2 = (1/(N-1)) \sum (x_i - mean)^2 #[doc(alias = "gsl_stats_variance_m")] -pub fn variance_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_variance_m(data.as_ptr(), stride, n, mean) } +pub fn variance_m(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_variance_m( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// The standard deviation is defined as the square root of the variance. This function returns the /// square root of the corresponding variance functions above. #[doc(alias = "gsl_stats_sd")] -pub fn sd(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_sd(data.as_ptr(), stride, n) } +pub fn sd(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_sd(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// The standard deviation is defined as the square root of the variance. This function returns the /// square root of the corresponding variance functions above. #[doc(alias = "gsl_stats_sd_m")] -pub fn sd_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_sd_m(data.as_ptr(), stride, n, mean) } +pub fn sd_m(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_sd_m( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function returns the total sum of squares (TSS) of data about the mean. For gsl_stats_tss_m @@ -102,8 +133,11 @@ pub fn sd_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { /// /// TSS = \sum (x_i - mean)^2 #[doc(alias = "gsl_stats_tss")] -pub fn tss(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_tss(data.as_ptr(), stride, n) } +pub fn tss(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_tss(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the total sum of squares (TSS) of data about the mean. For gsl_stats_tss_m @@ -112,8 +146,18 @@ pub fn tss(data: &[f64], stride: usize, n: usize) -> f64 { /// /// TSS = \sum (x_i - mean)^2 #[doc(alias = "gsl_stats_tss_m")] -pub fn tss_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_tss_m(data.as_ptr(), stride, n, mean) } +pub fn tss_m(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_tss_m( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function computes an unbiased estimate of the variance of data when the population mean @@ -123,15 +167,35 @@ pub fn tss_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { /// /// \Hat\sigma^2 = (1/N) \sum (x_i - \mu)^2 #[doc(alias = "gsl_stats_variance_with_fixed_mean")] -pub fn variance_with_fixed_mean(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_variance_with_fixed_mean(data.as_ptr(), stride, n, mean) } +pub fn variance_with_fixed_mean(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_variance_with_fixed_mean( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function calculates the standard deviation of data for a fixed population mean mean. The /// result is the square root of the corresponding variance function. #[doc(alias = "gsl_stats_sd_with_fixed_mean")] -pub fn sd_with_fixed_mean(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_sd_with_fixed_mean(data.as_ptr(), stride, n, mean) } +pub fn sd_with_fixed_mean(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_sd_with_fixed_mean( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function computes the absolute deviation from the mean of data, a dataset of length n with @@ -143,8 +207,11 @@ pub fn sd_with_fixed_mean(data: &[f64], stride: usize, n: usize, mean: f64) -> f /// more robust measure of the width of a distribution than the variance. This function computes the /// mean of data via a call to gsl_stats_mean. #[doc(alias = "gsl_stats_absdev")] -pub fn absdev(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_absdev(data.as_ptr(), stride, n) } +pub fn absdev(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_absdev(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function computes the absolute deviation of the dataset data relative to the given value of @@ -156,8 +223,18 @@ pub fn absdev(data: &[f64], stride: usize, n: usize) -> f64 { /// recomputing it), or wish to calculate the absolute deviation relative to another value (such as /// zero, or the median). #[doc(alias = "gsl_stats_absdev_m")] -pub fn absdev_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_absdev_m(data.as_ptr(), stride, n, mean) } +pub fn absdev_m(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_absdev_m( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function computes the skewness of data, a dataset of length n with stride stride. The @@ -171,8 +248,11 @@ pub fn absdev_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { /// The function computes the mean and estimated standard deviation of data via calls to [`mean`] /// and [`sd`]. #[doc(alias = "gsl_stats_skew")] -pub fn skew(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_skew(data.as_ptr(), stride, n) } +pub fn skew(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_skew(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function computes the skewness of the dataset data using the given values of the mean mean @@ -183,8 +263,19 @@ pub fn skew(data: &[f64], stride: usize, n: usize) -> f64 { /// These functions are useful if you have already computed the mean and standard deviation of data /// and want to avoid recomputing them. #[doc(alias = "gsl_stats_skew_m_sd")] -pub fn skew_m_sd(data: &[f64], stride: usize, n: usize, mean: f64, sd: f64) -> f64 { - unsafe { sys::gsl_stats_skew_m_sd(data.as_ptr(), stride, n, mean, sd) } +pub fn skew_m_sd(data: &V, mean: f64, sd: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_skew_m_sd( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + sd, + ) + } } /// This function computes the kurtosis of data, a dataset of length n with stride stride. The @@ -195,8 +286,11 @@ pub fn skew_m_sd(data: &[f64], stride: usize, n: usize, mean: f64, sd: f64) -> f /// The kurtosis measures how sharply peaked a distribution is, relative to its width. The kurtosis /// is normalized to zero for a Gaussian distribution. #[doc(alias = "gsl_stats_kurtosis")] -pub fn kurtosis(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_kurtosis(data.as_ptr(), stride, n) } +pub fn kurtosis(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_kurtosis(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function computes the kurtosis of the dataset data using the given values of the mean mean @@ -207,8 +301,19 @@ pub fn kurtosis(data: &[f64], stride: usize, n: usize) -> f64 { /// This function is useful if you have already computed the mean and standard deviation of data and /// want to avoid recomputing them. #[doc(alias = "gsl_stats_kurtosis_m_sd")] -pub fn kurtosis_m_sd(data: &[f64], stride: usize, n: usize, mean: f64, sd: f64) -> f64 { - unsafe { sys::gsl_stats_kurtosis_m_sd(data.as_ptr(), stride, n, mean, sd) } +pub fn kurtosis_m_sd(data: &V, mean: f64, sd: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_kurtosis_m_sd( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + sd, + ) + } } /// This function computes the lag-1 autocorrelation of the dataset data. @@ -217,15 +322,34 @@ pub fn kurtosis_m_sd(data: &[f64], stride: usize, n: usize, mean: f64, sd: f64) /// \over /// \sum_{i = 1}^{n} (x_{i} - \Hat\mu) (x_{i} - \Hat\mu)} #[doc(alias = "gsl_stats_lag1_autocorrelation")] -pub fn lag1_autocorrelation(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_lag1_autocorrelation(data.as_ptr(), stride, n) } +pub fn lag1_autocorrelation(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_lag1_autocorrelation( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function computes the lag-1 autocorrelation of the dataset data using the given value of /// the mean mean. #[doc(alias = "gsl_stats_lag1_autocorrelation_m")] -pub fn lag1_autocorrelation_m(data: &[f64], stride: usize, n: usize, mean: f64) -> f64 { - unsafe { sys::gsl_stats_lag1_autocorrelation_m(data.as_ptr(), stride, n, mean) } +pub fn lag1_autocorrelation_m(data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_lag1_autocorrelation_m( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) + } } /// This function computes the covariance of the datasets data1 and data2 which must both be of the @@ -233,30 +357,38 @@ pub fn lag1_autocorrelation_m(data: &[f64], stride: usize, n: usize, mean: f64) /// /// covar = (1/(n - 1)) \sum_{i = 1}^{n} (x_i - \Hat x) (y_i - \Hat y) #[doc(alias = "gsl_stats_covariance")] -pub fn covariance(data1: &[f64], stride1: usize, data2: &[f64], stride2: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_covariance(data1.as_ptr(), stride1, data2.as_ptr(), stride2, n) } +pub fn covariance(data1: &V, data2: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(data1, data2).unwrap(); + unsafe { + sys::gsl_stats_covariance( + V::as_slice(data1).as_ptr(), + V::stride(data1), + V::as_slice(data2).as_ptr(), + V::stride(data2), + V::len(data1), + ) + } } /// This function computes the covariance of the datasets data1 and data2 using the given values of /// the means, mean1 and mean2. This is useful if you have already computed the means of data1 and /// data2 and want to avoid recomputing them. #[doc(alias = "gsl_stats_covariance_m")] -pub fn covariance_m( - data1: &[f64], - stride1: usize, - data2: &[f64], - stride2: usize, - n: usize, - mean1: f64, - mean2: f64, -) -> f64 { +pub fn covariance_m(data1: &V, data2: &V, mean1: f64, mean2: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(data1, data2).unwrap(); unsafe { sys::gsl_stats_covariance_m( - data1.as_ptr(), - stride1, - data2.as_ptr(), - stride2, - n, + V::as_slice(data1).as_ptr(), + V::stride(data1), + V::as_slice(data2).as_ptr(), + V::stride(data2), + V::len(data1), mean1, mean2, ) @@ -272,8 +404,20 @@ pub fn covariance_m( /// \sqrt{1/(n-1) \sum (x_i - \Hat x)^2} \sqrt{1/(n-1) \sum (y_i - \Hat y)^2} /// } #[doc(alias = "gsl_stats_correlation")] -pub fn correlation(data1: &[f64], stride1: usize, data2: &[f64], stride2: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_correlation(data1.as_ptr(), stride1, data2.as_ptr(), stride2, n) } +pub fn correlation(data1: &V, data2: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(data1, data2).unwrap(); + unsafe { + sys::gsl_stats_correlation( + V::as_slice(data1).as_ptr(), + V::stride(data1), + V::as_slice(data2).as_ptr(), + V::stride(data2), + V::len(data1), + ) + } } /// This function computes the Spearman rank correlation coefficient between the datasets data1 and @@ -282,21 +426,21 @@ pub fn correlation(data1: &[f64], stride1: usize, data2: &[f64], stride2: usize, /// correlation between the ranked vectors x_R and y_R, where ranks are defined to be the average of /// the positions of an element in the ascending order of the values. #[doc(alias = "gsl_stats_spearman")] -pub fn spearman( - data1: &[f64], - stride1: usize, - data2: &[f64], - stride2: usize, - n: usize, - work: &mut [f64], -) -> f64 { +pub fn spearman(data1: &V, data2: &V, work: &mut [f64]) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(data1, data2).unwrap(); + if work.len() < 2 * V::len(data1) { + panic!("gsl::statistics::spearman: `work` is too small"); + } unsafe { sys::gsl_stats_spearman( - data1.as_ptr(), - stride1, - data2.as_ptr(), - stride2, - n, + V::as_slice(data1).as_ptr(), + V::stride(data1), + V::as_slice(data2).as_ptr(), + V::stride(data2), + V::len(data1), work.as_mut_ptr(), ) } @@ -307,8 +451,20 @@ pub fn spearman( /// /// \Hat\mu = (\sum w_i x_i) / (\sum w_i) #[doc(alias = "gsl_stats_wmean")] -pub fn wmean(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wmean(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wmean(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wmean( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function returns the estimated variance of the dataset data with stride stride and length @@ -321,36 +477,78 @@ pub fn wmean(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) - /// Note that this expression reduces to an unweighted variance with the familiar 1/(N-1) factor /// when there are N equal non-zero weights. #[doc(alias = "gsl_stats_wvariance")] -pub fn wvariance(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wvariance(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wvariance(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wvariance( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function returns the estimated variance of the weighted dataset data using the given /// weighted mean wmean. #[doc(alias = "gsl_stats_wvariance_m")] -pub fn wvariance_m( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - wmean: f64, -) -> f64 { - unsafe { sys::gsl_stats_wvariance_m(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean) } +pub fn wvariance_m(w: &V, data: &V, wmean: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_wvariance_m( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + ) + } } /// The standard deviation is defined as the square root of the variance. This function returns the /// square root of the corresponding variance function [`wvariance`] above. #[doc(alias = "gsl_stats_wsd")] -pub fn wsd(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wsd(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wsd(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wsd( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function returns the square root of the corresponding variance function /// [`wvariance_m`] above. #[doc(alias = "gsl_stats_wsd_m")] -pub fn wsd_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize, wmean: f64) -> f64 { - unsafe { sys::gsl_stats_wsd_m(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean) } +pub fn wsd_m(w: &V, data: &V, wmean: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wsd_m( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + ) + } } /// This function computes an unbiased estimate of the variance of the weighted dataset data when @@ -359,21 +557,18 @@ pub fn wsd_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize, w /// /// \Hat\sigma^2 = (\sum w_i (x_i - \mu)^2) / (\sum w_i) #[doc(alias = "gsl_stats_wvariance_with_fixed_mean")] -pub fn wvariance_with_fixed_mean( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - mean: f64, -) -> f64 { +pub fn wvariance_with_fixed_mean(w: &V, data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); unsafe { sys::gsl_stats_wvariance_with_fixed_mean( - w.as_ptr(), - wstride, - data.as_ptr(), - stride, - n, + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), mean, ) } @@ -382,16 +577,20 @@ pub fn wvariance_with_fixed_mean( /// The standard deviation is defined as the square root of the variance. This function returns the /// square root of the corresponding variance function above. #[doc(alias = "gsl_stats_wsd_with_fixed_mean")] -pub fn wsd_with_fixed_mean( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - mean: f64, -) -> f64 { +pub fn wsd_with_fixed_mean(w: &V, data: &V, mean: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); unsafe { - sys::gsl_stats_wsd_with_fixed_mean(w.as_ptr(), wstride, data.as_ptr(), stride, n, mean) + sys::gsl_stats_wsd_with_fixed_mean( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + mean, + ) } } @@ -401,8 +600,20 @@ pub fn wsd_with_fixed_mean( /// /// TSS = \sum w_i (x_i - wmean)^2 #[doc(alias = "gsl_stats_wtss")] -pub fn wtss(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wtss(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wtss(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wtss( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function returns the weighted total sum of squares (TSS) of data about the weighted mean. @@ -411,8 +622,21 @@ pub fn wtss(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> /// /// TSS = \sum w_i (x_i - wmean)^2 #[doc(alias = "gsl_stats_wtss_m")] -pub fn wtss_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize, wmean: f64) -> f64 { - unsafe { sys::gsl_stats_wtss_m(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean) } +pub fn wtss_m(w: &V, data: &V, wmean: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wtss_m( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + ) + } } /// This function computes the weighted absolute deviation from the weighted mean of data. The absolute deviation from the mean is defined @@ -420,68 +644,120 @@ pub fn wtss_m(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize, /// /// absdev = (\sum w_i |x_i - \Hat\mu|) / (\sum w_i) #[doc(alias = "gsl_stats_wabsdev")] -pub fn wabsdev(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wabsdev(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wabsdev(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wabsdev( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function computes the absolute deviation of the weighted dataset data about the given weighted mean wmean. #[doc(alias = "gsl_stats_wabsdev_m")] -pub fn wabsdev_m( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - wmean: f64, -) -> f64 { - unsafe { sys::gsl_stats_wabsdev_m(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean) } +pub fn wabsdev_m(w: &V, data: &V, wmean: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wabsdev_m( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + ) + } } /// This function computes the weighted skewness of the dataset data. /// /// skew = (\sum w_i ((x_i - \Hat x)/\Hat \sigma)^3) / (\sum w_i) #[doc(alias = "gsl_stats_wskew")] -pub fn wskew(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wskew(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wskew(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wskew( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function computes the weighted skewness of the dataset data using the given values of the /// weighted mean and weighted standard deviation, wmean and wsd. #[doc(alias = "gsl_stats_wskew_m_sd")] -pub fn wskew_m_sd( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - wmean: f64, - wsd: f64, -) -> f64 { - unsafe { sys::gsl_stats_wskew_m_sd(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean, wsd) } +pub fn wskew_m_sd(w: &V, data: &V, wmean: f64, wsd: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wskew_m_sd( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + wsd, + ) + } } /// This function computes the weighted kurtosis of the dataset data. /// /// kurtosis = ((\sum w_i ((x_i - \Hat x)/\Hat \sigma)^4) / (\sum w_i)) - 3 #[doc(alias = "gsl_stats_wkurtosis")] -pub fn wkurtosis(w: &[f64], wstride: usize, data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_wkurtosis(w.as_ptr(), wstride, data.as_ptr(), stride, n) } +pub fn wkurtosis(w: &V, data: &V) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); + unsafe { + sys::gsl_stats_wkurtosis( + V::as_slice(w).as_ptr(), + V::stride(data), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function computes the weighted kurtosis of the dataset data using the given values of the /// weighted mean and weighted standard deviation, wmean and wsd. #[doc(alias = "gsl_stats_wkurtosis_m_sd")] -pub fn wkurtosis_m_sd( - w: &[f64], - wstride: usize, - data: &[f64], - stride: usize, - n: usize, - wmean: f64, - wsd: f64, -) -> f64 { +pub fn wkurtosis_m_sd(w: &V, data: &V, wmean: f64, wsd: f64) -> f64 +where + V: Vector + ?Sized, +{ + check_equal_len(w, data).unwrap(); unsafe { - sys::gsl_stats_wkurtosis_m_sd(w.as_ptr(), wstride, data.as_ptr(), stride, n, wmean, wsd) + sys::gsl_stats_wkurtosis_m_sd( + V::as_slice(w).as_ptr(), + V::stride(w), + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + wmean, + wsd, + ) } } @@ -491,8 +767,11 @@ pub fn wkurtosis_m_sd( /// If you want instead to find the element with the largest absolute magnitude you will need to /// apply fabs or abs to your data before calling this function. #[doc(alias = "gsl_stats_max")] -pub fn max(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_max(data.as_ptr(), stride, n) } +pub fn max(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_max(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the minimum value in data, a dataset of length n with stride stride. The @@ -501,18 +780,32 @@ pub fn max(data: &[f64], stride: usize, n: usize) -> f64 { /// If you want instead to find the element with the smallest absolute magnitude you will need to /// apply fabs or abs to your data before calling this function. #[doc(alias = "gsl_stats_min")] -pub fn min(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_min(data.as_ptr(), stride, n) } +pub fn min(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_min(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function finds both the minimum and maximum values min, max in data in a single pass. /// /// Returns `(min, max)`. #[doc(alias = "gsl_stats_minmax")] -pub fn minmax(data: &[f64], stride: usize, n: usize) -> (f64, f64) { +pub fn minmax(data: &V) -> (f64, f64) +where + V: Vector + ?Sized, +{ let mut min = 0.; let mut max = 0.; - unsafe { sys::gsl_stats_minmax(&mut min, &mut max, data.as_ptr(), stride, n) }; + unsafe { + sys::gsl_stats_minmax( + &mut min, + &mut max, + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + }; (min, max) } @@ -520,16 +813,22 @@ pub fn minmax(data: &[f64], stride: usize, n: usize) -> (f64, f64) { /// stride. The maximum value is defined as the value of the element x_i which satisfies x_i >= x_j /// for all j. When there are several equal maximum elements then the first one is chosen. #[doc(alias = "gsl_stats_max_index")] -pub fn max_index(data: &[f64], stride: usize, n: usize) -> usize { - unsafe { sys::gsl_stats_max_index(data.as_ptr(), stride, n) } +pub fn max_index(data: &V) -> usize +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_max_index(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the index of the minimum value in data, a dataset of length n with stride /// stride. The minimum value is defined as the value of the element x_i which satisfies x_i >= x_j /// for all j. When there are several equal minimum elements then the first one is chosen. #[doc(alias = "gsl_stats_min_index")] -pub fn min_index(data: &[f64], stride: usize, n: usize) -> usize { - unsafe { sys::gsl_stats_min_index(data.as_ptr(), stride, n) } +pub fn min_index(data: &V) -> usize +where + V: Vector + ?Sized, +{ + unsafe { sys::gsl_stats_min_index(V::as_slice(data).as_ptr(), V::stride(data), V::len(data)) } } /// This function returns the indexes min_index, max_index of the minimum and maximum values in data @@ -537,11 +836,20 @@ pub fn min_index(data: &[f64], stride: usize, n: usize) -> usize { /// /// Returns `(min_index, max_index)`. #[doc(alias = "gsl_stats_minmax_index")] -pub fn minmax_index(data: &[f64], stride: usize, n: usize) -> (usize, usize) { +pub fn minmax_index(data: &V) -> (usize, usize) +where + V: Vector + ?Sized, +{ let mut min_index = 0; let mut max_index = 0; unsafe { - sys::gsl_stats_minmax_index(&mut min_index, &mut max_index, data.as_ptr(), stride, n) + sys::gsl_stats_minmax_index( + &mut min_index, + &mut max_index, + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) }; (min_index, max_index) } @@ -555,8 +863,17 @@ pub fn minmax_index(data: &[f64], stride: usize, n: usize) -> (usize, usize) { /// values, elements (n-1)/2 and n/2. Since the algorithm for computing the median involves /// interpolation this function always returns a floating-point number, even for integer data types. #[doc(alias = "gsl_stats_median_from_sorted_data")] -pub fn median_from_sorted_data(data: &[f64], stride: usize, n: usize) -> f64 { - unsafe { sys::gsl_stats_median_from_sorted_data(data.as_ptr(), stride, n) } +pub fn median_from_sorted_data(data: &V) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_median_from_sorted_data( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + ) + } } /// This function returns a quantile value of sorted_data, a double-precision array of length n with @@ -578,6 +895,16 @@ pub fn median_from_sorted_data(data: &[f64], stride: usize, n: usize) -> f64 { /// to 0.5. Since the algorithm for computing quantiles involves interpolation this function always /// returns a floating-point number, even for integer data types. #[doc(alias = "gsl_stats_quantile_from_sorted_data")] -pub fn quantile_from_sorted_data(data: &[f64], stride: usize, n: usize, f: f64) -> f64 { - unsafe { sys::gsl_stats_quantile_from_sorted_data(data.as_ptr(), stride, n, f) } +pub fn quantile_from_sorted_data(data: &V, f: f64) -> f64 +where + V: Vector + ?Sized, +{ + unsafe { + sys::gsl_stats_quantile_from_sorted_data( + V::as_slice(data).as_ptr(), + V::stride(data), + V::len(data), + f, + ) + } } diff --git a/src/types/permutation.rs b/src/types/permutation.rs index b20cf7d2..5b03fb77 100644 --- a/src/types/permutation.rs +++ b/src/types/permutation.rs @@ -10,6 +10,8 @@ use crate::{MatrixComplexF32, MatrixComplexF64, MatrixF32, VectorF64}; use std::fmt::{self, Debug, Formatter}; use std::slice; +use super::vector::VectorMut; + // FIXME: Permutations have the same representation as vectors. // Do we want to wrap vectors? (The wrapping is to preserve invariants.) ffi_wrapper!(Permutation, *mut sys::gsl_permutation, gsl_permutation_free); @@ -163,20 +165,36 @@ impl Permutation { /// This function applies the permutation to the array data of size n with stride stride. #[doc(alias = "gsl_permute")] - pub fn permute(&mut self, data: &mut [f64], stride: usize) -> Result<(), Value> { + pub fn permute(&mut self, data: &mut V) -> Result<(), Value> + where + V: VectorMut + ?Sized, + { let ret = unsafe { let data_ptr = sys::gsl_permutation_data(self.unwrap_shared()); - sys::gsl_permute(data_ptr, data.as_mut_ptr(), stride, data.len() as _) + sys::gsl_permute( + data_ptr, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) }; result_handler!(ret, ()) } /// This function applies the inverse of the permutation p to the array data of size n with stride stride. #[doc(alias = "gsl_permute_inverse")] - pub fn permute_inverse(&mut self, data: &mut [f64], stride: usize) -> Result<(), Value> { + pub fn permute_inverse(&mut self, data: &mut V) -> Result<(), Value> + where + V: VectorMut + ?Sized, + { let ret = unsafe { let data_ptr = sys::gsl_permutation_data(self.unwrap_shared()); - sys::gsl_permute_inverse(data_ptr, data.as_mut_ptr(), stride, data.len() as _) + sys::gsl_permute_inverse( + data_ptr, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), + ) }; result_handler!(ret, ()) } diff --git a/src/wavelet_transforms.rs b/src/wavelet_transforms.rs index 0a132112..222d20a1 100644 --- a/src/wavelet_transforms.rs +++ b/src/wavelet_transforms.rs @@ -27,23 +27,22 @@ level of the transform. /// 2 or if insufficient workspace is provided. pub mod one_dimension { use crate::ffi::FFI; + use crate::vector::VectorMut; use crate::Value; #[doc(alias = "gsl_wavelet_transform")] - pub fn transform( + pub fn transform + ?Sized>( w: &crate::Wavelet, - data: &mut [f64], - stride: usize, - n: usize, + data: &mut V, dir: crate::WaveletDirection, work: &mut crate::WaveletWorkspace, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_wavelet_transform( w.unwrap_shared(), - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), dir.into(), work.unwrap_unique(), ) @@ -52,19 +51,17 @@ pub mod one_dimension { } #[doc(alias = "gsl_wavelet_transform_forward")] - pub fn transform_forward( + pub fn transform_forward + ?Sized>( w: &crate::Wavelet, - data: &mut [f64], - stride: usize, - n: usize, + data: &mut V, work: &mut crate::WaveletWorkspace, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_wavelet_transform_forward( w.unwrap_shared(), - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), work.unwrap_unique(), ) }; @@ -72,19 +69,17 @@ pub mod one_dimension { } #[doc(alias = "gsl_wavelet_transform_inverse")] - pub fn transform_inverse( + pub fn transform_inverse + ?Sized>( w: &crate::Wavelet, - data: &mut [f64], - stride: usize, - n: usize, + data: &mut V, work: &mut crate::WaveletWorkspace, ) -> Result<(), Value> { let ret = unsafe { sys::gsl_wavelet_transform_inverse( w.unwrap_shared(), - data.as_mut_ptr(), - stride, - n, + V::as_mut_slice(data).as_mut_ptr(), + V::stride(data), + V::len(data), work.unwrap_unique(), ) }; @@ -119,6 +114,7 @@ pub mod two_dimension { #[doc(alias = "gsl_wavelet2d_transform")] pub fn transform( w: &crate::Wavelet, + // FIXME: needs a Matrix data: &mut [f64], tda: usize, size1: usize,