diff --git a/base/array.jl b/base/array.jl index a6364188b8977..f247c8a32dbee 100644 --- a/base/array.jl +++ b/base/array.jl @@ -517,8 +517,25 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr) -_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr) +function copyto_check_length!(dest::Array, src) + len = length(dest) + i = 0 + for x in src + i == len && + throw(ErrorException("iterator returned more elements than its declared length")) + i += 1 + @inbounds dest[i] = x + end + if i < len + throw(ErrorException("iterator returned fewer elements than its declared length")) + end + return dest +end + +_collect(::Type{T}, itr, isz::HasLength) where {T} = + copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr) +_collect(::Type{T}, itr, isz::HasShape) where {T} = + copyto_check_length!(similar(Array{T}, axes(itr)), itr) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr @@ -561,7 +578,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - copyto!(_similar_for(cont, eltype(itr), itr, isz), itr) + copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) a = _similar_for(cont, eltype(itr), itr, isz) @@ -618,6 +635,9 @@ function collect(itr::Generator) else y = iterate(itr) if y === nothing + if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0 + throw(ErrorException("iterator returned fewer elements than its declared length")) + end return _array_for(et, itr.iter, isz) end v1, st = y @@ -667,6 +687,10 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T return collect_to!(new, itr, i+1, st) end end + i-1 < length(dest) && + throw(ErrorException("iterator returned fewer elements than its declared length")) + i-1 > length(dest) && + throw(ErrorException("iterator returned more elements than its declared length")) return dest end diff --git a/test/arrayops.jl b/test/arrayops.jl index 20a0fb26e3963..c2d3469ac2118 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2524,3 +2524,29 @@ Base.view(::T25958, args...) = args @test t[end,end,1] == @view(t[end,end,1]) == @views t[end,end,1] @test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end] end + +@testset "collect on iterator with incorrect length" begin + # Iterator with declared length too large + struct InvalidIter1 end + Base.length(::InvalidIter1) = 2 + Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter1()) + @test_throws ErrorException collect(Any, InvalidIter1()) + @test_throws ErrorException collect(Int, InvalidIter1()) + @test_throws ErrorException [x for x in InvalidIter1()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter1()]) != 2 + + # Iterator with declared length too small + struct InvalidIter2 end + Base.length(::InvalidIter2) = 2 + Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter2()) + @test_throws ErrorException collect(Any, InvalidIter2()) + @test_throws ErrorException collect(Int, InvalidIter2()) + @test_throws ErrorException [x for x in InvalidIter2()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter2()]) != 2 +end \ No newline at end of file