Skip to content

Commit

Permalink
Throw an error when collecting from iterators with inconsistent length
Browse files Browse the repository at this point in the history
Check that HasLength and HasShape iterators return the same number of elements
as indicated by length(itr). An error was already thrown in some situations when
the number of elements was higher than the declared length, but not all.
  • Loading branch information
nalimilan committed Oct 1, 2018
1 parent 201cba5 commit 3f4a404
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
30 changes: 27 additions & 3 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,25 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
function copyto_check_length!(dest::Array, src)
len = length(dest)
i = 0
for x in src
i == len &&
throw(ErrorException("iterator returned more elements than its declared length"))
i += 1
@inbounds dest[i] = x
end
if i < len
throw(ErrorException("iterator returned fewer elements than its declared length"))
end
return dest
end

_collect(::Type{T}, itr, isz::HasLength) where {T} =
copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_collect(::Type{T}, itr, isz::HasShape) where {T} =
copyto_check_length!(similar(Array{T}, axes(itr)), itr)
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
a = Vector{T}()
for x in itr
Expand Down Expand Up @@ -561,7 +578,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
a = _similar_for(cont, eltype(itr), itr, isz)
Expand Down Expand Up @@ -618,6 +635,9 @@ function collect(itr::Generator)
else
y = iterate(itr)
if y === nothing
if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0
throw(ErrorException("iterator returned fewer elements than its declared length"))
end
return _array_for(et, itr.iter, isz)
end
v1, st = y
Expand Down Expand Up @@ -667,6 +687,10 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T
return collect_to!(new, itr, i+1, st)
end
end
i-1 < length(dest) &&
throw(ErrorException("iterator returned fewer elements than its declared length"))
i-1 > length(dest) &&
throw(ErrorException("iterator returned more elements than its declared length"))
return dest
end

Expand Down
26 changes: 26 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2524,3 +2524,29 @@ Base.view(::T25958, args...) = args
@test t[end,end,1] == @view(t[end,end,1]) == @views t[end,end,1]
@test t[end,end,end] == @view(t[end,end,end]) == @views t[end,end,end]
end

@testset "collect on iterator with incorrect length" begin
# Iterator with declared length too large
struct InvalidIter1 end
Base.length(::InvalidIter1) = 2
Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter1())
@test_throws ErrorException collect(Any, InvalidIter1())
@test_throws ErrorException collect(Int, InvalidIter1())
@test_throws ErrorException [x for x in InvalidIter1()]
# Should also throw ErrorException
@test_broken length(Int[x for x in InvalidIter1()]) != 2

# Iterator with declared length too small
struct InvalidIter2 end
Base.length(::InvalidIter2) = 2
Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter2())
@test_throws ErrorException collect(Any, InvalidIter2())
@test_throws ErrorException collect(Int, InvalidIter2())
@test_throws ErrorException [x for x in InvalidIter2()]
# Should also throw ErrorException
@test_broken length(Int[x for x in InvalidIter2()]) != 2
end

0 comments on commit 3f4a404

Please sign in to comment.