Skip to content

Commit

Permalink
Small fixups to #600 (#601)
Browse files Browse the repository at this point in the history
* Fix multi-column `_small_filt_fir!` for `ndims(x)>2`

Missed in #600.

* Consistently require mutli-column state to match input layout

* `CartesianIndex()`

likely to lower into better code

---------

Co-authored-by: wheeheee <[email protected]>
  • Loading branch information
martinholters and wheeheee authored Dec 3, 2024
1 parent 1be88b5 commit be65e83
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
11 changes: 4 additions & 7 deletions src/Filters/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ end
function filt!(out::AbstractArray, f::SecondOrderSections{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
biquads = f.biquads
ncols = Base.trailingsize(x, 2)

size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || size(si, 2) != length(biquads) || (N > 2 && Base.trailingsize(si, 3) != ncols)) &&
(size(si, 1) != 2 || size(si, 2) != length(biquads) || (N > 2 && size(si)[3:end] != size(x)[2:end])) &&
throw(ArgumentError("si must be 2 x nbiquads or 2 x nbiquads x nsignals"))

initial_si = si
si = similar(si, axes(si)[1:2])
for col in CartesianIndices(axes(x)[2:end])
copyto!(si, view(initial_si, :, :, N > 2 ? col : 1))
copyto!(si, view(initial_si, :, :, N > 2 ? col : CartesianIndex()))
_filt!(out, si, f, x, col)
end
out
Expand Down Expand Up @@ -99,14 +98,12 @@ end
# filt! variant that preserves si
function filt!(out::AbstractArray, f::Biquad{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
ncols = Base.trailingsize(x, 2)

size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || (N > 1 && Base.trailingsize(si, 2) != ncols)) &&
(size(si, 1) != 2 || (N > 1 && size(si)[2:end] != size(x)[2:end])) &&
throw(ArgumentError("si must have two rows and 1 or nsignals columns"))

for col in CartesianIndices(axes(x)[2:end])
_filt!(out, si[1, N > 1 ? col : 1], si[2, N > 1 ? col : 1], f, x, col)
_filt!(out, si[1, N > 1 ? col : CartesianIndex()], si[2, N > 1 ? col : CartesianIndex()], f, x, col)
end
out
end
Expand Down
9 changes: 4 additions & 5 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
bs = length(b)
sz = max(as, bs)
silen = sz - 1
ncols = size(x, 2)

if size(si, 1) != silen
throw(ArgumentError("initial state vector si must have max(length(a),length(b))-1 rows"))
elseif N > 1 && size(si, 2) != ncols
elseif N > 1 && size(si)[2:end] != size(x)[2:end]
throw(ArgumentError("initial state si must be a vector or have the same number of columns as x"))
end

Expand All @@ -70,7 +69,7 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
si = similar(si, axes(si, 1))
for col in CartesianIndices(axes(x)[2:end])
# Reset the filter state
copyto!(si, view(initial_si, :, N > 1 ? col : 1))
copyto!(si, view(initial_si, :, N > 1 ? col : CartesianIndex()))
if as > 1
_filt_iir!(out, b, a, x, si, col)
else
Expand Down Expand Up @@ -124,7 +123,7 @@ const SMALL_FILT_VECT_CUTOFF = 19
si_end = Symbol(:si_, silen)

quote
col = colv isa Val{:DF2} ? 1 : colv
col = colv isa Val{:DF2} ? CartesianIndex() : colv
N <= SMALL_FILT_VECT_CUTOFF && checkbounds(siarr, $silen)
Base.@nextract $silen si siarr
for i in axes(x, 1)
Expand Down Expand Up @@ -152,7 +151,7 @@ function _small_filt_fir!(
bs < 2 && throw(ArgumentError("invalid tuple size"))
length(h) != bs && throw(ArgumentError("length(h) does not match bs"))
b = ntuple(j -> h[j], Val(bs))
for col in axes(x, 2)
for col in CartesianIndices(axes(x)[2:end])
v_si = N > 1 ? view(si, :, col) : si
_filt_fir!(out, b, x, v_si, col)
end
Expand Down
20 changes: 16 additions & 4 deletions test/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,22 @@ end
sz = (10, ntuple(n -> n+1, Val(D))...)
y_ref = filt(b, a, ones(sz[1]))
x = ones(sz)
@test all(col -> col y_ref, eachslice(filt(b, a, x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x); dims=ntuple(n -> n+1, Val(D))))
slicedims = ntuple(n -> n+1, Val(D))
@test all(col -> col y_ref, eachslice(filt(b, a, x); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x); dims=slicedims))
# with si given
@test all(col -> col y_ref, eachslice(filt(b, a, x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(PolynomialRatio(b, a), x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x, zeros(2, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x, zeros(2, 1, sz[2:end]...)); dims=slicedims))
# use _small_filt_fir!
b = [0.1, 0.1]
a = [1.0]
y_ref = filt(b, a, ones(sz[1]))
@test all(col -> col y_ref, eachslice(filt(b, a, x); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=slicedims))
end

#
Expand Down

0 comments on commit be65e83

Please sign in to comment.