From af5ae5f6e1e52170ff810b00dc3e5a1c22685d01 Mon Sep 17 00:00:00 2001
From: Will Tebbutt <wt0881@my.bristol.ac.uk>
Date: Sat, 29 Jul 2023 13:10:31 +0100
Subject: [PATCH] Boundscheck implementation + test

---
 Project.toml       |  2 +-
 src/trace.jl       |  8 ++++++++
 test/test_trace.jl | 23 +++++++++++++++++++++++
 3 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index ae79819..ad89f2b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Umlaut"
 uuid = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"
 authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
-version = "0.5.1"
+version = "0.5.2"
 
 [deps]
 CompilerPluginTools = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3638"
diff --git a/src/trace.jl b/src/trace.jl
index 920c1e8..3ae06ee 100644
--- a/src/trace.jl
+++ b/src/trace.jl
@@ -269,6 +269,14 @@ function rewrite_special_cases(st::Expr)
     if Meta.isexpr(ex, :splatnew)
         ex = Expr(:call, __splatnew__, ex.args...)
     end
+    # replace :($(Expr(:boundscheck))) with just `true`
+    if Meta.isexpr(ex, :boundscheck)
+        ex = true
+    end
+    # same for arguments
+    if ex isa Expr
+        ex.args = [Meta.isexpr(arg, :boundscheck) ? true : arg for arg in ex.args]
+    end
     return Meta.isexpr(st, :(=)) ? Expr(:(=), st.args[1], ex) : ex
 end
 rewrite_special_cases(st) = st
diff --git a/test/test_trace.jl b/test/test_trace.jl
index cdb2e37..1f7f699 100644
--- a/test/test_trace.jl
+++ b/test/test_trace.jl
@@ -44,6 +44,29 @@ end
 
 ###############################################################################
 
+@eval _alt_getindex(x::Vector, i) = Base.arrayref($(Expr(:boundscheck)), x, i)
+@eval function _boundscheck_foo()
+    v = $(Expr(:boundscheck))
+    return v ? 1 : 0
+end
+
+@testset "trace: :boundscheck" begin
+    @testset "boundscheck as argument" begin
+        x = randn(5)
+        val, tape = trace(_alt_getindex, x, 1)
+        @test val == getindex(x, 1)
+        @test play!(tape, getindex, x, 1) == getindex(x, 1)
+    end
+
+    @testset "boundscheck as rhs statement" begin
+        val, tape = trace(_boundscheck_foo)
+        @test val == _boundscheck_foo()
+        @test play!(tape, _boundscheck_foo) == val
+    end
+end
+
+###############################################################################
+
 
 @testset "trace: bcast" begin
     # bcast