From fd9129706a40706350106f8c8b9476b043aefc7c Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 5 Oct 2018 21:40:40 +0200 Subject: [PATCH] Use _copy_impl! instead of custom function --- base/abstractarray.jl | 7 ++++++- base/array.jl | 39 ++++++++++++++++----------------------- test/arrayops.jl | 18 +++++++++--------- 3 files changed, 31 insertions(+), 33 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 2aae8f80caffa..e5b60bc4f4fd1 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -640,7 +640,9 @@ emptymutable(itr, ::Type{U}) where {U} = Vector{U}() ## from general iterable to any array -function copyto!(dest::AbstractArray, src) +copyto!(dest::AbstractArray, src) = _copyto_impl!(dest, src, true) + +function _copyto_impl!(dest::AbstractArray, src, allowshorter::Bool) destiter = eachindex(dest) y = iterate(destiter) for x in src @@ -649,6 +651,9 @@ function copyto!(dest::AbstractArray, src) dest[y[1]] = x y = iterate(destiter, y[2]) end + if !allowshorter && y !== nothing + throw(ArgumentError(string("source has fewer elements than destination"))) + end return dest end diff --git a/base/array.jl b/base/array.jl index 10b2f35042787..7f5a0fbb06ab4 100644 --- a/base/array.jl +++ b/base/array.jl @@ -265,16 +265,24 @@ end Copy `N` elements from collection `src` starting at offset `so`, to array `dest` starting at offset `do`. Return `dest`. """ -function copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer) where T +copyto!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer) where {T} = + _copyto_impl!(dest, doffs, src, soffs, n, true) + +function _copyto_impl!(dest::Array{T}, doffs::Integer, src::Array{T}, soffs::Integer, n::Integer, + allowshorter::Bool) where T n == 0 && return dest n > 0 || throw(ArgumentError(string("tried to copy n=", n, " elements, but n should be nonnegative"))) if soffs < 1 || doffs < 1 || soffs+n-1 > length(src) || doffs+n-1 > length(dest) throw(BoundsError()) end + if !allowshorter && soffs+n-1 < length(dest) + throw(ArgumentError("source has fewer elements than destination")) + end unsafe_copyto!(dest, doffs, src, soffs, n) end -copyto!(dest::Array{T}, src::Array{T}) where {T} = copyto!(dest, 1, src, 1, length(src)) +_copyto_impl!(dest::Array{T}, src::Array{T}, allowshorter::Bool) where {T} = + _copyto_impl!(dest, 1, src, 1, length(src), allowshorter) # N.B: The generic definition in multidimensional.jl covers, this, this is just here # for bootstrapping purposes. @@ -517,25 +525,10 @@ julia> collect(Float64, 1:2:5) """ collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr)) -function copyto_check_length!(dest::Array, src) - len = length(dest) - i = 0 - for x in src - i == len && - throw(ErrorException("iterator returned more elements than its declared length")) - i += 1 - @inbounds dest[i] = x - end - if i < len - throw(ErrorException("iterator returned fewer elements than its declared length")) - end - return dest -end - _collect(::Type{T}, itr, isz::HasLength) where {T} = - copyto_check_length!(Vector{T}(undef, Int(length(itr)::Integer)), itr) + _copyto_impl!(Vector{T}(undef, Int(length(itr)::Integer)), itr, false) _collect(::Type{T}, itr, isz::HasShape) where {T} = - copyto_check_length!(similar(Array{T}, axes(itr)), itr) + _copyto_impl!(similar(Array{T}, axes(itr)), itr, false) function _collect(::Type{T}, itr, isz::SizeUnknown) where T a = Vector{T}() for x in itr @@ -578,7 +571,7 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A) collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr)) _collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) = - copyto_check_length!(_similar_for(cont, eltype(itr), itr, isz), itr) + _copyto_impl!(_similar_for(cont, eltype(itr), itr, isz), itr, false) function _collect(cont, itr, ::HasEltype, isz::SizeUnknown) a = _similar_for(cont, eltype(itr), itr, isz) @@ -636,7 +629,7 @@ function collect(itr::Generator) y = iterate(itr) if y === nothing if isa(isz, Union{HasLength, HasShape}) && length(itr) != 0 - throw(ErrorException("iterator returned fewer elements than its declared length")) + throw(ArgumentError("iterator returned fewer elements than its declared length")) end return _array_for(et, itr.iter, isz) end @@ -689,9 +682,9 @@ function collect_to!(dest::AbstractArray{T}, itr, offs, st) where T end lastidx = lastindex(dest) i-1 < lastidx && - throw(ErrorException("iterator returned fewer elements than its declared length")) + throw(ArgumentError("iterator returned fewer elements than its declared length")) i-1 > lastidx && - throw(ErrorException("iterator returned more elements than its declared length")) + throw(ArgumentError("iterator returned more elements than its declared length")) return dest end diff --git a/test/arrayops.jl b/test/arrayops.jl index dfcb774beb85f..b99a80ba2ef56 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -2531,11 +2531,11 @@ end Base.length(::InvalidIter1) = 2 Base.iterate(::InvalidIter1, i=1) = i > 1 ? nothing : (i, (i + 1)) - @test_throws ErrorException collect(InvalidIter1()) - @test_throws ErrorException collect(Any, InvalidIter1()) - @test_throws ErrorException collect(Int, InvalidIter1()) - @test_throws ErrorException [x for x in InvalidIter1()] - # Should also throw ErrorException + @test_throws ArgumentError collect(InvalidIter1()) + @test_throws ArgumentError collect(Any, InvalidIter1()) + @test_throws ArgumentError collect(Int, InvalidIter1()) + @test_throws ArgumentError [x for x in InvalidIter1()] + # Should also throw ArgumentError @test_broken length(Int[x for x in InvalidIter1()]) != 2 # Iterator with declared length too small @@ -2543,11 +2543,11 @@ end Base.length(::InvalidIter2) = 2 Base.iterate(::InvalidIter2, i=1) = i > 3 ? nothing : (i, (i + 1)) - @test_throws ErrorException collect(InvalidIter2()) - @test_throws ErrorException collect(Any, InvalidIter2()) - @test_throws ErrorException collect(Int, InvalidIter2()) + @test_throws ArgumentError collect(InvalidIter2()) + @test_throws ArgumentError collect(Any, InvalidIter2()) + @test_throws ArgumentError collect(Int, InvalidIter2()) # These cases cannot be tested without writing to invalid memory # unless the function checked bounds on each iteration (#29458) # @test_throws ErrorException [x for x in InvalidIter2()] - # @test_throws ErrorException Int[x for x in InvalidIter2()] + # @test_broken length(Int[x for x in InvalidIter2()]) != 2 end \ No newline at end of file