diff --git a/base/int.jl b/base/int.jl index 8a80f90f7e2c1..01cf24c009fe3 100644 --- a/base/int.jl +++ b/base/int.jl @@ -96,6 +96,44 @@ inv(x::Integer) = float(one(x)) / float(x) # skip promotion for system integer types (/)(x::BitInteger, y::BitInteger) = float(x) / float(y) + +""" + mul_hi(a::T, b::T) where {T<:Base.BitInteger} + +Returns the higher half of the product of `a` and `b`. + +# Examples +```jldoctest +julia> Base.mul_hi(12345678987654321, 123456789) +82624 + +julia> (widen(12345678987654321) * 123456789) >> 64 +82624 + +julia> Base.mul_hi(0xff, 0xff) +0xfe +``` +""" +function mul_hi(a::T, b::T) where {T<:BitInteger} + ((widen(a)*b) >>> (sizeof(a)*8)) % T +end + +function mul_hi(a::UInt128, b::UInt128) + shift = sizeof(a)*4 + mask = typemax(UInt128) >> shift + a1, a2 = a >>> shift, a & mask + b1, b2 = b >>> shift, b & mask + a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2 + carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift + a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry +end + +function mul_hi(a::Int128, b::Int128) + shift = sizeof(a)*8 - 1 + t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128 + (mul_hi(a % UInt128, b % UInt128) - t1 - t2) % Int128 +end + """ isodd(x::Number) -> Bool diff --git a/base/multinverses.jl b/base/multinverses.jl index 70033de12fcd8..d1b455c736009 100644 --- a/base/multinverses.jl +++ b/base/multinverses.jl @@ -2,7 +2,7 @@ module MultiplicativeInverses -import Base: div, divrem, rem, unsigned +import Base: div, divrem, mul_hi, rem, unsigned using Base: IndexLinear, IndexCartesian, tail export multiplicativeinverse @@ -134,33 +134,13 @@ struct UnsignedMultiplicativeInverse{T<:Unsigned} <: MultiplicativeInverse{T} end UnsignedMultiplicativeInverse(x::Unsigned) = UnsignedMultiplicativeInverse{typeof(x)}(x) -# Returns the higher half of the product a*b -function _mul_high(a::T, b::T) where {T<:Union{Signed, Unsigned}} - ((widen(a)*b) >>> (sizeof(a)*8)) % T -end - -function _mul_high(a::UInt128, b::UInt128) - shift = sizeof(a)*4 - mask = typemax(UInt128) >> shift - a1, a2 = a >>> shift, a & mask - b1, b2 = b >>> shift, b & mask - a1b1, a1b2, a2b1, a2b2 = a1*b1, a1*b2, a2*b1, a2*b2 - carry = ((a1b2 & mask) + (a2b1 & mask) + (a2b2 >>> shift)) >>> shift - a1b1 + (a1b2 >>> shift) + (a2b1 >>> shift) + carry -end -function _mul_high(a::Int128, b::Int128) - shift = sizeof(a)*8 - 1 - t1, t2 = (a >> shift) & b % UInt128, (b >> shift) & a % UInt128 - (_mul_high(a % UInt128, b % UInt128) - t1 - t2) % Int128 -end - function div(a::T, b::SignedMultiplicativeInverse{T}) where T - x = _mul_high(a, b.multiplier) + x = mul_hi(a, b.multiplier) x += (a*b.addmul) % T ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % T) end function div(a::T, b::UnsignedMultiplicativeInverse{T}) where T - x = _mul_high(a, b.multiplier) + x = mul_hi(a, b.multiplier) x = ifelse(b.add, convert(T, convert(T, (convert(T, a - x) >>> 1)) + x), x) ifelse(b.divisor == 1, a, x >>> b.shift) end diff --git a/base/public.jl b/base/public.jl index 8777a454c920a..71fd31d6607f3 100644 --- a/base/public.jl +++ b/base/public.jl @@ -71,6 +71,9 @@ public isoperator, isunaryoperator, +# Integer math + mul_hi, + # C interface cconvert, unsafe_convert, diff --git a/test/numbers.jl b/test/numbers.jl index dc4f2cb613d77..fd45dd8793e91 100644 --- a/test/numbers.jl +++ b/test/numbers.jl @@ -2523,6 +2523,22 @@ Base.:(==)(x::TestNumber, y::TestNumber) = x.inner == y.inner Base.abs(x::TestNumber) = TestNumber(abs(x.inner)) @test abs2(TestNumber(3+4im)) == TestNumber(25) +@testset "mul_hi" begin + n = 1000 + ground_truth(x, y) = ((widen(x)*y) >> (8*sizeof(typeof(x)))) % typeof(x) + for T in [UInt8, UInt16, UInt32, UInt64, UInt128, Int8, Int16, Int32, Int64, Int128] + for trait1 in [typemin, typemax] + for trait2 in [typemin, typemax] + x, y = trait1(T), trait2(T) + @test Base.mul_hi(x, y) === ground_truth(x, y) + end + end + for (x, y) in zip(rand(T, n), rand(T, n)) + @test Base.mul_hi(x, y) === ground_truth(x, y) + end + end +end + @testset "multiplicative inverses" begin function testmi(numrange, denrange) for d in denrange