From 8a7078184f9b0b1f40b60268031d37b62ebeaa13 Mon Sep 17 00:00:00 2001 From: David K Date: Mon, 2 Sep 2024 15:46:31 +0100 Subject: [PATCH] dialects: (csl) added `csl.addressof_fn` (#3135) Functions like `csl.addressof`, but takes a prop `fn_name` instead of an SSAValue. Result type is a single const pointer to a function (this is a limitation imposed by the CSL language. --- tests/filecheck/backend/csl/print_csl.mlir | 6 ++++ tests/filecheck/dialects/csl/ops.mlir | 4 +++ xdsl/backend/csl/print_csl.py | 5 ++++ xdsl/dialects/csl/csl.py | 34 ++++++++++++++++++++++ 4 files changed, 49 insertions(+) diff --git a/tests/filecheck/backend/csl/print_csl.mlir b/tests/filecheck/backend/csl/print_csl.mlir index 552e34a9cf..a671fa7bad 100644 --- a/tests/filecheck/backend/csl/print_csl.mlir +++ b/tests/filecheck/backend/csl/print_csl.mlir @@ -114,6 +114,10 @@ %ptr_to_arr = "csl.addressof"(%uninit_array) : (memref<10xf32>) -> !csl.ptr, #csl, #csl> %ptr_to_val = "csl.addressof"(%const27) : (i16) -> !csl.ptr, #csl> + %ptr_1_fn = "csl.addressof_fn"() <{fn_name = @args_no_return}> : () -> !csl.ptr<(i32, i32) -> (), #csl, #csl> + %ptr_2_fn = "csl.addressof_fn"() <{fn_name = @no_args_return}> : () -> !csl.ptr<() -> (f32), #csl, #csl> + + "csl.export"(%global_ptr) <{ type = !csl.ptr, #csl>, @@ -453,6 +457,8 @@ csl.func @builtins() { // CHECK-NEXT: const const_ptr : [*]const i32 = &const_array; // CHECK-NEXT: var ptr_to_arr : *[10]f32 = &uninit_array; // CHECK-NEXT: const ptr_to_val : *const i16 = &const27; +// CHECK-NEXT: const ptr_1_fn : *const fn(i32, i32) void = &args_no_return; +// CHECK-NEXT: const ptr_2_fn : *const fn() f32 = &no_args_return; // CHECK-NEXT: comptime { // CHECK-NEXT: @export_symbol(global_ptr, "ptr_name"); // CHECK-NEXT: } diff --git a/tests/filecheck/dialects/csl/ops.mlir b/tests/filecheck/dialects/csl/ops.mlir index 7352ad283e..3bfff292ee 100644 --- a/tests/filecheck/dialects/csl/ops.mlir +++ b/tests/filecheck/dialects/csl/ops.mlir @@ -82,6 +82,8 @@ csl.func @initialize() { %many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl> %single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl, #csl> + %function_ptr = "csl.addressof_fn"() <{fn_name = @initialize}> : () -> !csl.ptr<() -> (), #csl, #csl> + %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<10xf32>, i32, i32) -> !csl %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl @@ -370,6 +372,7 @@ csl.func @builtins() { // CHECK-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr, #csl> // CHECK-NEXT: %many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl> // CHECK-NEXT: %single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl // CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl // CHECK-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl @@ -605,6 +608,7 @@ csl.func @builtins() { // CHECK-GENERIC-NEXT: %scalar_ptr = "csl.addressof"(%scalar) : (i32) -> !csl.ptr, #csl> // CHECK-GENERIC-NEXT: %many_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl> // CHECK-GENERIC-NEXT: %single_arr_ptr = "csl.addressof"(%arr) : (memref<10xf32>) -> !csl.ptr, #csl, #csl> +// CHECK-GENERIC-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-GENERIC-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl // CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl // CHECK-GENERIC-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index 265bb67ebe..2b7282d5d5 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -544,6 +544,11 @@ def print_block(self, body: Block): ty = cast(csl.PtrType, res.type) use = self._var_use(res, ty.constness.data.value) self.print(f"{use} = &{val_name};") + + case csl.AddressOfFnOp(fn_name=name, res=res): + ty = cast(csl.PtrType, res.type) + use = self._var_use(res, ty.constness.data.value) + self.print(f"{use} = &{name.string_value()};") case csl.SymbolExportOp(value=val, type=ty) as exp: name = exp.get_name() q_name = f'"{name}"' diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index 5e94714e31..b38b6fe600 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -1529,6 +1529,39 @@ def verify_(self) -> None: return super().verify_() +@irdl_op_definition +class AddressOfFnOp(IRDLOperation): + """ + Takes the address of a function from symbol ref. + + Result has to have kind SINGLE and constness CONST + """ + + name = "csl.addressof_fn" + fn_name = prop_def(SymbolRefAttr) + + res = result_def(PtrType) + + def __init__(self, fn_name: str | SymbolRefAttr): + if isinstance(fn_name, str): + fn_name = SymbolRefAttr(fn_name) + + super().__init__(properties={"fn_name": fn_name}) + + def verify_(self) -> None: + ty = self.res.type + assert isa(ty, PtrType) + if not isa(ty.type, FunctionType): + raise VerifyException("Pointed to type must be a function type") + if ty.kind.data != PtrKind.SINGLE: + raise VerifyException("Pointer kind must be 'single'") + + if ty.constness.data != PtrConst.CONST: + raise VerifyException("Function pointers must be const") + + return super().verify_() + + @irdl_op_definition class AddressOfOp(IRDLOperation): """ @@ -1733,6 +1766,7 @@ def __init__(self, struct_a: Operation, struct_b: Operation): [ Add16Op, Add16cOp, + AddressOfFnOp, AddressOfOp, And16Op, CallOp,