From 3f4a4047625a2f01be8cfe31b39a86c27a7e6c55 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Mon, 1 Oct 2018 20:45:16 +0200 Subject: [PATCH] Throw an error when collecting from iterators with inconsistent length Check that HasLength and HasShape iterators return the same number of elements as indicated by length(itr). An error was already thrown in some situations when the number of elements was higher than the declared length, but not all. --- base/array.jl | 30 +++++++++++++++++++++++++++--- test/arrayops.jl | 26 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/base/array.jl b/base/array.jl index a6364188b8977..f247c8a32dbee 100644 --- a/base/array.jl +++ b/base/array.jl @@ -517,8 +517,25 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr) -_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr) +function copyto_check_length!(dest::Array, src) + len = length(dest) + i = 0 + for x in src + i == len && + throw(ErrorException("iterator returned more elements than its declared length")) + i += 1 + @inbounds dest[i] = x + end + if i < len + throw(ErrorException("iterator returned fewer elements than its declared length")) + end + return dest +end + +_collect(::Type{T}, itr, isz::HasLength) where {T} = + copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr) +_collect(::Type{T}, itr, isz::HasShape) where {T} = + copyto_check_length!(similar(Array{T}, axes(itr)), itr) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr @@ -561,7 +578,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - copyto!(_similar_for(cont, eltype(itr), itr, isz), itr) + copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) a = _similar_for(cont, eltype(itr), itr, isz) @@ -618,6 +635,9 @@ function collect(itr::Generator) else y = iterate(itr) if y === nothing + if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0 + throw(ErrorException("iterator returned fewer elements than its declared length")) + end return _array_for(et, itr.iter, isz) end v1, st = y @@ -667,6 +687,10 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T return collect_to!(new, itr, i+1, st) end end + i-1 < length(dest) && + throw(ErrorException("iterator returned fewer elements than its declared length")) + i-1 > length(dest) && + throw(ErrorException("iterator returned more elements than its declared length")) return dest end diff --git a/test/arrayops.jl b/test/arrayops.jl index 20a0fb26e3963..c2d3469ac2118 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2524,3 +2524,29 @@ Base.view(::T25958, args...) = args @test t[end,end,1] == @view(t[end,end,1]) == @views t[end,end,1] @test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end] end + +@testset "collect on iterator with incorrect length" begin + # Iterator with declared length too large + struct InvalidIter1 end + Base.length(::InvalidIter1) = 2 + Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter1()) + @test_throws ErrorException collect(Any, InvalidIter1()) + @test_throws ErrorException collect(Int, InvalidIter1()) + @test_throws ErrorException [x for x in InvalidIter1()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter1()]) != 2 + + # Iterator with declared length too small + struct InvalidIter2 end + Base.length(::InvalidIter2) = 2 + Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1)) + + @test_throws ErrorException collect(InvalidIter2()) + @test_throws ErrorException collect(Any, InvalidIter2()) + @test_throws ErrorException collect(Int, InvalidIter2()) + @test_throws ErrorException [x for x in InvalidIter2()] + # Should also throw ErrorException + @test_broken length(Int[x for x in InvalidIter2()]) != 2 +end \ No newline at end of file