From 3f4a4047625a2f01be8cfe31b39a86c27a7e6c55 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Mon, 1 Oct 2018 20:45:16 +0200 Subject: [PATCH 1/3] Throw an error when collecting from iterators with inconsistent length Check that HasLength and HasShape iterators return the same number of elements as indicated by length(itr). An error was already thrown in some situations when the number of elements was higher than the declared length, but not all. --- base/array.jl | 30 +++++++++++++++++++++++++++--- test/arrayops.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/base/array.jl b/base/array.jl index a6364188b8977..f247c8a32dbee 100644 --- a/base/array.jl +++ b/base/array.jl @@ -517,8 +517,25 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr) -_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr) +function copyto_check_length!(dest::Array, src) + len = length(dest) + i = 0 + for x in src + i == len && + throw(ErrorException("iterator returned more elements than its declared length")) + i += 1 + @inbounds dest[i] = x + end + if i < len + throw(ErrorException("iterator returned fewer elements than its declared length")) + end + return dest +end + +_collect(::Type{T}, itr, isz::HasLength) where {T} = + copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr) +_collect(::Type{T}, itr, isz::HasShape) where {T} = + copyto_check_length!(similar(Array{T}, axes(itr)), itr) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr @@ -561,7 +578,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - 
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr) + copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) a = _similar_for(cont, eltype(itr), itr, isz) @@ -618,6 +635,9 @@ function collect(itr::Generator) else y = iterate(itr) if y === nothing + if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0 + throw(ErrorException("iterator returned fewer elements than its declared length")) + end return _array_for(et, itr.iter, isz) end v1, st = y @@ -667,6 +687,10 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T return collect_to!(new, itr, i+1, st) end end + i-1 < length(dest) && + throw(ErrorException("iterator returned fewer elements than its declared length")) + i-1 > length(dest) && + throw(ErrorException("iterator returned more elements than its declared length")) return dest end diff --git a/test/arrayops.jl b/test/arrayops.jl index 20a0fb26e3963..c2d3469ac2118 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2524,3 +2524,29 @@ Base.view(::T25958, args...) = args @test t[end,end,1] == @view(t[end,end,1]) == @views t[end,end,1] @test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end] end + +@testset "collect on iterator with incorrect length" begin + # Iterator with declared length too large + struct InvalidIter1 end + Base.length(::InvalidIter1) = 2 + Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter1()) + @test_throws ErrorException collect(Any, InvalidIter1()) + @test_throws ErrorException collect(Int, InvalidIter1()) + @test_throws ErrorException [x for x in InvalidIter1()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter1()]) != 2 + + # Iterator with declared length too small + struct InvalidIter2 end + Base.length(::InvalidIter2) = 2 + Base.iterate(::InvalidIter2, i=1) = i > 3 ? 
nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter2()) + @test_throws ErrorException collect(Any, InvalidIter2()) + @test_throws ErrorException collect(Int, InvalidIter2()) + @test_throws ErrorException [x for x in InvalidIter2()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter2()]) != 2 +end \ No newline at end of file From e811fa67fbad09d7e7facab5102bfc8ca7a94d20 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 5 Oct 2018 14:12:58 +0200 Subject: [PATCH 2/3] Review fixes --- base/array.jl | 5 +++-- test/arrayops.jl | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/base/array.jl b/base/array.jl index f247c8a32dbee..10b2f35042787 100644 --- a/base/array.jl +++ b/base/array.jl @@ -687,9 +687,10 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T return collect_to!(new, itr, i+1, st) end end - i-1 < length(dest) && + lastidx = lastindex(dest) + i-1 < lastidx && throw(ErrorException("iterator returned fewer elements than its declared length")) - i-1 > length(dest) && + i-1 > lastidx && throw(ErrorException("iterator returned more elements than its declared length")) return dest end diff --git a/test/arrayops.jl b/test/arrayops.jl index c2d3469ac2118..dfcb774beb85f 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2546,7 +2546,8 @@ end @test_throws ErrorException collect(InvalidIter2()) @test_throws ErrorException collect(Any, InvalidIter2()) @test_throws ErrorException collect(Int, InvalidIter2()) - @test_throws ErrorException [x for x in InvalidIter2()] - # Should also throw ErrorException - @test_broken length(Int[x for x in InvalidIter2()]) != 2 + # These cases cannot be tested without writing to invalid memory + # unless the function checked bounds on each iteration (#29458) + # @test_throws ErrorException [x for x in InvalidIter2()] + # @test_throws ErrorException Int[x for x in InvalidIter2()] end \ No newline at end of file From 
c7b6896df9e92420ffa81bca70ba0f1422f05a9e Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 5 Oct 2018 21:40:40 +0200 Subject: [PATCH 3/3] Use _copyto_impl! instead of custom function --- base/abstractarray.jl | 13 +++++++++++-- base/array.jl | 34 ++++++++++++---------------------- base/multidimensional.jl | 6 +++++- test/arrayops.jl | 36 +++++++++++++++++------------------- 4 files changed, 45 insertions(+), 44 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 2aae8f80caffa..4d3a1a83ff36e 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -640,7 +640,9 @@ emptymutable(itr, ::Type{U}) where {U} = Vector{U}() ## from general iterable to any array -function copyto!(dest::AbstractArray, src) +copyto!(dest::AbstractArray, src) = _copyto_impl!(dest, src, true) + +function _copyto_impl!(dest::AbstractArray, src, allowshorter::Bool) destiter = eachindex(dest) y = iterate(destiter) for x in src @@ -649,6 +651,9 @@ function copyto!(dest::AbstractArray, src) dest[y[1]] = x y = iterate(destiter, y[2]) end + if !allowshorter && y !== nothing + throw(ArgumentError(string("source has fewer elements than destination"))) + end return dest end @@ -720,8 +725,12 @@ end ## copy between abstract arrays - generally more efficient ## since a single index variable can be used. 
-copyto!(dest::AbstractArray, src::AbstractArray) = +function _copyto_impl!(dest::AbstractArray, src::AbstractArray, allowshorter::Bool) + if !allowshorter && length(src) < length(dest) + throw(ArgumentError("source has fewer elements than destination")) + end copyto!(IndexStyle(dest), dest, IndexStyle(src), src) +end function copyto!(::IndexStyle, dest::AbstractArray, ::IndexStyle, src::AbstractArray) destinds, srcinds = LinearIndices(dest), LinearIndices(src) diff --git a/base/array.jl b/base/array.jl index 10b2f35042787..ba9ee844630b9 100644 --- a/base/array.jl +++ b/base/array.jl @@ -274,7 +274,12 @@ function copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, unsafe_copyto!(dest, doffs, src, soffs, n) end -copyto!(dest::Array{T}, src::Array{T}) where {T} = copyto!(dest, 1, src, 1, length(src)) +function _copyto_impl!(dest::Array{T}, src::Array{T}, allowshorter::Bool) where {T} + if !allowshorter && length(src) < length(dest) + throw(ArgumentError("source has fewer elements than destination")) + end + copyto!(dest, 1, src, 1, length(src)) +end # N.B: The generic definition in multidimensional.jl covers, this, this is just here # for bootstrapping purposes. 
@@ -517,25 +522,10 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -function copyto_check_length!(dest::Array, src) - len = length(dest) - i = 0 - for x in src - i == len && - throw(ErrorException("iterator returned more elements than its declared length")) - i += 1 - @inbounds dest[i] = x - end - if i < len - throw(ErrorException("iterator returned fewer elements than its declared length")) - end - return dest -end - _collect(::Type{T}, itr, isz::HasLength) where {T} = - copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr) + _copyto_impl!(Vector{T}(undef, Int(length(itr)::Integer)), itr, false) _collect(::Type{T}, itr, isz::HasShape) where {T} = - copyto_check_length!(similar(Array{T}, axes(itr)), itr) + _copyto_impl!(similar(Array{T}, axes(itr)), itr, false) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr @@ -578,7 +568,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr) + _copyto_impl!(_similar_for(cont, eltype(itr), itr, isz), itr, false) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) a = _similar_for(cont, eltype(itr), itr, isz) @@ -636,7 +626,7 @@ function collect(itr::Generator) y = iterate(itr) if y === nothing if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0 - throw(ErrorException("iterator returned fewer elements than its declared length")) + throw(ArgumentError("iterator returned fewer elements than its declared length")) end return _array_for(et, itr.iter, isz) end @@ -689,9 +679,9 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T end lastidx = lastindex(dest) i-1 < lastidx && - throw(ErrorException("iterator returned fewer elements than its declared 
length")) + throw(ArgumentError("iterator returned fewer elements than its declared length")) i-1 > lastidx && - throw(ErrorException("iterator returned more elements than its declared length")) + throw(ArgumentError("iterator returned more elements than its declared length")) return dest end diff --git a/base/multidimensional.jl b/base/multidimensional.jl index 1a78b84c1e6a3..5947a16add827 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -815,8 +815,12 @@ julia> y """ copyto!(dest, src) -function copyto!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N} +function _copyto_impl!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}, + allowshorter::Bool) where {T,N} checkbounds(dest, axes(src)...) + if !allowshorter && length(src) < length(dest) + throw(ArgumentError("source has fewer elements than destination")) + end src′ = unalias(dest, src) for I in eachindex(IndexStyle(src′,dest), src′) @inbounds dest[I] = src′[I] diff --git a/test/arrayops.jl b/test/arrayops.jl index dfcb774beb85f..c708a6dfd4ee0 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2525,29 +2525,27 @@ Base.view(::T25958, args...) = args @test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end] end +# Iterator with declared length too large +struct InvalidIter1 end +Base.length(::InvalidIter1) = 2 +Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1)) +# Iterator with declared length too small +struct InvalidIter2 end +Base.length(::InvalidIter2) = 2 +Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1)) @testset "collect on iterator with incorrect length" begin - # Iterator with declared length too large - struct InvalidIter1 end - Base.length(::InvalidIter1) = 2 - Base.iterate(::InvalidIter1, i=1) = i > 1 ? 
nothing : (i, (i + 1)) - - @test_throws ErrorException collect(InvalidIter1()) - @test_throws ErrorException collect(Any, InvalidIter1()) - @test_throws ErrorException collect(Int, InvalidIter1()) - @test_throws ErrorException [x for x in InvalidIter1()] - # Should also throw ErrorException + @test_throws ArgumentError collect(InvalidIter1()) + @test_throws ArgumentError collect(Any, InvalidIter1()) + @test_throws ArgumentError collect(Int, InvalidIter1()) + @test_throws ArgumentError [x for x in InvalidIter1()] + # Should also throw ArgumentError @test_broken length(Int[x for x in InvalidIter1()]) != 2 - # Iterator with declared length too small - struct InvalidIter2 end - Base.length(::InvalidIter2) = 2 - Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1)) - - @test_throws ErrorException collect(InvalidIter2()) - @test_throws ErrorException collect(Any, InvalidIter2()) - @test_throws ErrorException collect(Int, InvalidIter2()) + @test_throws ArgumentError collect(InvalidIter2()) + @test_throws ArgumentError collect(Any, InvalidIter2()) + @test_throws ArgumentError collect(Int, InvalidIter2()) # These cases cannot be tested without writing to invalid memory # unless the function checked bounds on each iteration (#29458) # @test_throws ErrorException [x for x in InvalidIter2()] - # @test_throws ErrorException Int[x for x in InvalidIter2()] + # @test_broken length(Int[x for x in InvalidIter2()]) != 2 end \ No newline at end of file