From af5ae5f6e1e52170ff810b00dc3e5a1c22685d01 Mon Sep 17 00:00:00 2001 From: Will Tebbutt <wt0881@my.bristol.ac.uk> Date: Sat, 29 Jul 2023 13:10:31 +0100 Subject: [PATCH] Boundscheck implementation + test --- Project.toml | 2 +- src/trace.jl | 8 ++++++++ test/test_trace.jl | 23 +++++++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ae79819..ad89f2b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Umlaut" uuid = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841" authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"] -version = "0.5.1" +version = "0.5.2" [deps] CompilerPluginTools = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3638" diff --git a/src/trace.jl b/src/trace.jl index 920c1e8..3ae06ee 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -269,6 +269,14 @@ function rewrite_special_cases(st::Expr) if Meta.isexpr(ex, :splatnew) ex = Expr(:call, __splatnew__, ex.args...) end + # replace :($(Expr(:boundscheck))) with just `true` + if Meta.isexpr(ex, :boundscheck) + ex = true + end + # same for arguments + if ex isa Expr + ex.args = [Meta.isexpr(arg, :boundscheck) ? true : arg for arg in ex.args] + end return Meta.isexpr(st, :(=)) ? Expr(:(=), st.args[1], ex) : ex end rewrite_special_cases(st) = st diff --git a/test/test_trace.jl b/test/test_trace.jl index cdb2e37..1f7f699 100644 --- a/test/test_trace.jl +++ b/test/test_trace.jl @@ -44,6 +44,29 @@ end ############################################################################### +@eval _alt_getindex(x::Vector, i) = Base.arrayref($(Expr(:boundscheck)), x, i) +@eval function _boundscheck_foo() + v = $(Expr(:boundscheck)) + return v ? 1 : 0 +end + +@testset "trace: :boundscheck" begin + @testset "boundscheck as argument" begin + x = randn(5) + val, tape = trace(_alt_getindex, x, 1) + @test val == getindex(x, 1) + @test play!(tape, getindex, x, 1) == getindex(x, 1) + end + + @testset "boundscheck as rhs statement" begin + val, tape = trace(_boundscheck_foo) + @test val == _boundscheck_foo() + @test play!(tape, _boundscheck_foo) == val + end +end + +############################################################################### + @testset "trace: bcast" begin # bcast