Skip to content

Commit

Permalink
Use _copy_impl! instead of custom function
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Oct 5, 2018
1 parent e811fa6 commit fd91297
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 33 deletions.
7 changes: 6 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,9 @@ emptymutable(itr, ::Type{U}) where {U} = Vector{U}()

## from general iterable to any array

function copyto!(dest::AbstractArray, src)
copyto!(dest::AbstractArray, src) = _copyto_impl!(dest, src, true)

function _copyto_impl!(dest::AbstractArray, src, allowshorter::Bool)
destiter = eachindex(dest)
y = iterate(destiter)
for x in src
Expand All @@ -649,6 +651,9 @@ function copyto!(dest::AbstractArray, src)
dest[y[1]] = x
y = iterate(destiter, y[2])
end
if !allowshorter && y !== nothing
throw(ArgumentError(string("source has fewer elements than destination")))
end
return dest
end

Expand Down
39 changes: 16 additions & 23 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,24 @@ end
Copy `N` elements from collection `src` starting at offset `so`, to array `dest` starting at
offset `do`. Return `dest`.
"""
function copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer) where T
copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer) where {T} =
_copyto_impl!(dest, doffs, src, soffs, n, true)

function _copyto_impl!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer,
allowshorter::Bool) where T
n == 0 && return dest
n > 0 || throw(ArgumentError(string("tried to copy n=", n, " elements, but n should be nonnegative")))
if soffs < 1 || doffs < 1 || soffs+n-1 > length(src) || doffs+n-1 > length(dest)
throw(BoundsError())
end
if !allowshorter && soffs+n-1 < length(dest)
throw(ArgumentError("source has fewer elements than destination"))
end
unsafe_copyto!(dest, doffs, src, soffs, n)
end

copyto!(dest::Array{T}, src::Array{T}) where {T} = copyto!(dest, 1, src, 1, length(src))
_copyto_impl!(dest::Array{T}, src::Array{T}, allowshorter::Bool) where {T} =
_copyto_impl!(dest, 1, src, 1, length(src), allowshorter)

# N.B: The generic definition in multidimensional.jl covers, this, this is just here
# for bootstrapping purposes.
Expand Down Expand Up @@ -517,25 +525,10 @@ julia> collect(Float64, 1:2:5)
"""
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))

function copyto_check_length!(dest::Array, src)
len = length(dest)
i = 0
for x in src
i == len &&
throw(ErrorException("iterator returned more elements than its declared length"))
i += 1
@inbounds dest[i] = x
end
if i < len
throw(ErrorException("iterator returned fewer elements than its declared length"))
end
return dest
end

_collect(::Type{T}, itr, isz::HasLength) where {T} =
copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
_copyto_impl!(Vector{T}(undef, Int(length(itr)::Integer)), itr, false)
_collect(::Type{T}, itr, isz::HasShape) where {T} =
copyto_check_length!(similar(Array{T}, axes(itr)), itr)
_copyto_impl!(similar(Array{T}, axes(itr)), itr, false)
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
a = Vector{T}()
for x in itr
Expand Down Expand Up @@ -578,7 +571,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr)
_copyto_impl!(_similar_for(cont, eltype(itr), itr, isz), itr, false)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
a = _similar_for(cont, eltype(itr), itr, isz)
Expand Down Expand Up @@ -636,7 +629,7 @@ function collect(itr::Generator)
y = iterate(itr)
if y === nothing
if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0
throw(ErrorException("iterator returned fewer elements than its declared length"))
throw(ArgumentError("iterator returned fewer elements than its declared length"))
end
return _array_for(et, itr.iter, isz)
end
Expand Down Expand Up @@ -689,9 +682,9 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T
end
lastidx = lastindex(dest)
i-1 < lastidx &&
throw(ErrorException("iterator returned fewer elements than its declared length"))
throw(ArgumentError("iterator returned fewer elements than its declared length"))
i-1 > lastidx &&
throw(ErrorException("iterator returned more elements than its declared length"))
throw(ArgumentError("iterator returned more elements than its declared length"))
return dest
end

Expand Down
18 changes: 9 additions & 9 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2531,23 +2531,23 @@ end
Base.length(::InvalidIter1) = 2
Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter1())
@test_throws ErrorException collect(Any, InvalidIter1())
@test_throws ErrorException collect(Int, InvalidIter1())
@test_throws ErrorException [x for x in InvalidIter1()]
# Should also throw ErrorException
@test_throws ArgumentError collect(InvalidIter1())
@test_throws ArgumentError collect(Any, InvalidIter1())
@test_throws ArgumentError collect(Int, InvalidIter1())
@test_throws ArgumentError [x for x in InvalidIter1()]
# Should also throw ArgumentError
@test_broken length(Int[x for x in InvalidIter1()]) != 2

# Iterator with declared length too small
struct InvalidIter2 end
Base.length(::InvalidIter2) = 2
Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1))

@test_throws ErrorException collect(InvalidIter2())
@test_throws ErrorException collect(Any, InvalidIter2())
@test_throws ErrorException collect(Int, InvalidIter2())
@test_throws ArgumentError collect(InvalidIter2())
@test_throws ArgumentError collect(Any, InvalidIter2())
@test_throws ArgumentError collect(Int, InvalidIter2())
# These cases cannot be tested without writing to invalid memory
# unless the function checked bounds on each iteration (#29458)
# @test_throws ErrorException [x for x in InvalidIter2()]
# @test_throws ErrorException Int[x for x in InvalidIter2()]
# @test_broken length(Int[x for x in InvalidIter2()]) != 2
end

0 comments on commit fd91297

Please sign in to comment.