Skip to content

Commit

Permalink
core: Make IRDLOperation results typed
Browse files Browse the repository at this point in the history
Now, defining a result with `resname = result_def(T)` will make
`resname` have the type `OpResult[T]` (a subtype of `SSAValue[T]`).

This removes many `cast` and `isa`/`isattr` calls from our codebase.

stack-info: PR: #3991, branch: math-fehr/stack/6
  • Loading branch information
math-fehr committed Feb 28, 2025
1 parent 93f5abe commit 4d4bb1e
Show file tree
Hide file tree
Showing 25 changed files with 64 additions and 123 deletions.
4 changes: 2 additions & 2 deletions docs/Toy/toy/dialects/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def infer_shape(cls, op: Operation) -> None:
if isinstance(op_res_type := op.res.type, TensorType):
assert op_lhs_type.get_shape() == op_res_type.get_shape()
else:
op.res.type = op.lhs.type
op.res.type = op_lhs_type


@irdl_op_definition
Expand Down Expand Up @@ -312,7 +312,7 @@ def infer_shape(cls, op: Operation) -> None:
if isinstance(op_res_type := op.res.type, TensorType):
assert op_lhs_type.get_shape() == op_res_type.get_shape()
else:
op.res.type = op.lhs.type
op.res.type = op_lhs_type


@irdl_op_definition
Expand Down
2 changes: 1 addition & 1 deletion docs/Toy/toy/rewrites/lower_toy_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def match_and_rewrite(self, op: toy.ConstantOp, rewriter: PatternRewriter):
# When lowering the constant operation, we allocate and assign the constant
# values to a corresponding memref allocation.

tensor_type = cast(toy.TensorTypeF64, op.res.type)
tensor_type = op.res.type
memref_type = convert_tensor_to_memref(tensor_type)
alloc = insert_alloc_and_dealloc(memref_type, op, rewriter)

Expand Down
5 changes: 1 addition & 4 deletions docs/Toy/toy/rewrites/optimise_toy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import cast

from xdsl.dialects.builtin import (
DenseIntOrFPElementsAttr,
)
Expand Down Expand Up @@ -51,8 +49,7 @@ def match_and_rewrite(self, op: ReshapeOp, rewriter: PatternRewriter):
# Input defined by another transpose? If not, no match.
return

t = cast(TensorTypeF64, op.res.type)
new_op = ReshapeOp.from_input_and_type(reshape_input_op.arg, t)
new_op = ReshapeOp.from_input_and_type(reshape_input_op.arg, op.res.type)
rewriter.replace_matched_op(new_op)


Expand Down
8 changes: 3 additions & 5 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import IO, Literal, cast
from typing import IO, Literal

from xdsl.dialects import arith, csl, memref, scf
from xdsl.dialects.builtin import (
Expand Down Expand Up @@ -682,13 +682,11 @@ def print_block(self, body: Block):
self.variables[res] = f"({arr_name}[{idx_args}])"
case csl.AddressOfOp(value=val, res=res):
val_name = self._get_variable_name_for(val)
ty = cast(csl.PtrType, res.type)
use = self._var_use(res, ty.constness.data.value)
use = self._var_use(res, res.type.constness.data.value)
self.print(f"{use} = &{val_name};")

case csl.AddressOfFnOp(fn_name=name, res=res):
ty = cast(csl.PtrType, res.type)
use = self._var_use(res, ty.constness.data.value)
use = self._var_use(res, res.type.constness.data.value)
self.print(f"{use} = &{name.string_value()};")
case csl.DirectionOp(dir=d, res=res):
self._print_or_promote_to_inline_expr(res, str.upper(d.data))
Expand Down
2 changes: 1 addition & 1 deletion xdsl/backend/riscv/lowering/convert_memref_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def match_and_rewrite(self, op: memref.SubviewOp, rewriter: PatternRewriter):
source_type = source.type
assert isinstance(source_type, MemRefType)
source_type = cast(MemRefType[Attribute], source_type)
result_type = cast(MemRefType[Attribute], result.type)
result_type = result.type

result_layout_attr = result_type.layout
if isinstance(result_layout_attr, NoneAttr):
Expand Down
19 changes: 4 additions & 15 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,8 +1147,6 @@ class GetMemDsdOp(_GetDsdOp):
)

def verify_(self) -> None:
if not isinstance(self.result.type, DsdType):
raise VerifyException("DSD type is not DsdType")
if self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]:
raise VerifyException("DSD type must be memory DSD")
if self.result.type.data == DsdKind.mem1d_dsd and len(self.sizes) != 1:
Expand Down Expand Up @@ -1190,8 +1188,6 @@ class GetFabDsdOp(_GetDsdOp):
wavelet_index_offset = opt_prop_def(BoolAttr)

def verify_(self) -> None:
if not isinstance(self.result.type, DsdType):
raise VerifyException("DSD type is not DsdType")
if self.result.type.data not in [DsdKind.fabin_dsd, DsdKind.fabout_dsd]:
raise VerifyException("DSD type must be fabric DSD")
if len(self.sizes) != 1:
Expand Down Expand Up @@ -1226,8 +1222,7 @@ class SetDsdBaseAddrOp(IRDLOperation):

def verify_(self) -> None:
if (
not isinstance(self.result.type, DsdType)
or not isinstance(self.op.type, DsdType)
not isinstance(self.op.type, DsdType)
or self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]
or self.op.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]
):
Expand Down Expand Up @@ -1266,8 +1261,7 @@ class IncrementDsdOffsetOp(IRDLOperation):

def verify_(self) -> None:
if (
not isinstance(self.result.type, DsdType)
or not isinstance(self.op.type, DsdType)
not isinstance(self.op.type, DsdType)
or self.result.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]
or self.op.type.data not in [DsdKind.mem1d_dsd, DsdKind.mem4d_dsd]
):
Expand All @@ -1293,8 +1287,7 @@ class SetDsdLengthOp(IRDLOperation):

def verify_(self) -> None:
if (
not isinstance(self.result.type, DsdType)
or not isinstance(self.op.type, DsdType)
not isinstance(self.op.type, DsdType)
or self.result.type.data == DsdKind.mem4d_dsd
):
raise VerifyException(
Expand All @@ -1321,8 +1314,7 @@ class SetDsdStrideOp(IRDLOperation):

def verify_(self) -> None:
if (
not isinstance(self.result.type, DsdType)
or not isinstance(self.op.type, DsdType)
not isinstance(self.op.type, DsdType)
or self.result.type.data != DsdKind.mem1d_dsd
):
raise VerifyException(f"{self.name} can only operate on mem1d_dsd type")
Expand Down Expand Up @@ -1926,9 +1918,6 @@ def _verify_memref_addr(self, val_ty: MemRefType[Attribute], res_ty: PtrType):
)

def verify_(self) -> None:
if not isinstance(self.res.type, PtrType):
raise VerifyException("Result type must be a pointer")

val_ty = self.value.type
res_ty = self.res.type
if isa(val_ty, MemRefType[Attribute]):
Expand Down
6 changes: 0 additions & 6 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,6 @@ def get(

def verify_(self) -> None:
memref_type = self.memref.type
if not isinstance(memref_type, MemRefType):
raise VerifyException("expected result to be a memref")
memref_type = cast(MemRefType[Attribute], memref_type)

dyn_dims = [x for x in memref_type.shape.data if x.data == -1]
if len(dyn_dims) != len(self.dynamic_sizes):
Expand Down Expand Up @@ -345,9 +342,6 @@ def get(

def verify_(self) -> None:
memref_type = self.memref.type
if not isinstance(memref_type, MemRefType):
raise VerifyException("expected result to be a memref")
memref_type = cast(MemRefType[Attribute], memref_type)

dyn_dims = [x for x in memref_type.shape.data if x.data == -1]
if len(dyn_dims) != len(self.dynamic_sizes):
Expand Down
4 changes: 0 additions & 4 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,8 +1184,6 @@ def __init__(
def verify_(self) -> None:
if not self.writeonly:
return
if not isinstance(self.rd.type, IntRegisterType):
return
if self.rd.type.is_allocated and self.rd.type != Registers.ZERO:
raise VerifyException(
"When in 'writeonly' mode, destination must be register x0 (a.k.a. 'zero'), "
Expand Down Expand Up @@ -1330,8 +1328,6 @@ def __init__(
def verify_(self) -> None:
if self.writeonly is None:
return
if not isinstance(self.rd.type, IntRegisterType):
return
if self.rd.type.is_allocated and self.rd.type != Registers.ZERO:
raise VerifyException(
"When in 'writeonly' mode, destination must be register x0 (a.k.a. 'zero'), "
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def get_permutation(self) -> tuple[int, ...]:
def verify_(self) -> None:
# Operand and result types are checked before the custom `verify_`
o_type = cast(TensorType[Attribute], self.operand.type)
r_type = cast(TensorType[Attribute], self.result.type)
r_type = self.result.type

o_shape = o_type.get_shape()
r_shape = r_type.get_shape()
Expand Down
10 changes: 4 additions & 6 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ def verify_(self) -> None:
"buffer-semantic destination operands."
)
if len(self.res) > 0:
res_type = cast(TempType[Attribute], self.res[0].type)
res_type = self.res[0].type
for other in self.res[1:]:
other = cast(TempType[Attribute], other.type)
other = other.type
if res_type.bounds != other.bounds:
raise VerifyException(
"Expected all output types bounds to be equals."
Expand Down Expand Up @@ -674,7 +674,7 @@ def get_bounds(self):
return self.bounds
else:
assert self.res
res_type = cast(TempType[Attribute], self.res[0].type)
res_type = self.res[0].type
return res_type.bounds


Expand Down Expand Up @@ -1506,9 +1506,7 @@ def verify_(self) -> None:
types = [ot.elem if isinstance(ot, ResultType) else ot for ot in self.arg.types]
apply = cast(ApplyOp, self.parent_op())
if len(apply.res) > 0:
res_types = [
cast(TempType[Attribute], r.type).element_type for r in apply.res
]
res_types = [r.type.element_type for r in apply.res]
else:
res_types = [
cast(FieldType[Attribute], o.type).element_type for o in apply.dest
Expand Down
10 changes: 4 additions & 6 deletions xdsl/dialects/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,16 @@ def parse(cls, parser: Parser) -> Self:
return reshape

def verify_(self) -> None:
if (
not isinstance(source_type := self.source.type, TensorType)
or not isinstance(shape_type := self.shape.type, TensorType)
or not isinstance(res_type := self.result.type, TensorType)
):
if not isinstance(
source_type := self.source.type, TensorType
) or not isinstance(shape_type := self.shape.type, TensorType):
raise ValueError(
"tensor elementwise operation operands and result must be of type TensorType"
)

source_type = cast(TensorType[Attribute], source_type)
shape_type = cast(TensorType[Attribute], shape_type)
res_type = cast(TensorType[Attribute], res_type)
res_type = self.result.type

if source_type.element_type != res_type.element_type:
raise VerifyException(
Expand Down
3 changes: 1 addition & 2 deletions xdsl/interpreters/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from xdsl.interpreters.builtin import xtype_for_el_type
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Attribute
from xdsl.traits import SymbolTable


Expand All @@ -22,7 +21,7 @@ class MemRefFunctions(InterpreterFunctions):
def run_alloc(
self, interpreter: Interpreter, op: memref.AllocOp, args: PythonValues
) -> PythonValues:
memref_type = cast(memref.MemRefType[Attribute], op.memref.type)
memref_type = op.memref.type

shape = memref_type.get_shape()
size = prod(shape)
Expand Down
3 changes: 0 additions & 3 deletions xdsl/interpreters/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,6 @@ def run_get_register(
) -> PythonValues:
attr = op.res.type

if not isinstance(attr, riscv.RISCVRegisterType):
raise InterpretationError(f"Unexpected type {attr}, expected register type")

if not attr.is_allocated:
raise InterpretationError(
f"Cannot get value for unallocated register {attr}"
Expand Down
2 changes: 0 additions & 2 deletions xdsl/interpreters/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from xdsl.interpreters.builtin import xtype_for_el_type
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.interpreters.utils.ptr import TypedPtr
from xdsl.ir import Attribute
from xdsl.utils.exceptions import InterpretationError


Expand All @@ -24,7 +23,6 @@ def run_empty(
) -> tuple[Any, ...]:
result_type = op.tensor.type
assert isinstance(result_type, TensorType)
result_type = cast(TensorType[Attribute], result_type)
result_shape = list(result_type.get_shape())
xtype = xtype_for_el_type(result_type.element_type, interpreter.index_bitwidth)
return (
Expand Down
26 changes: 15 additions & 11 deletions xdsl/irdl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def __init__(
self.constr = range_constr_coercion(attr)


class VarOpResult(tuple[OpResult, ...]):
class VarOpResult(Generic[AttributeInvT], tuple[OpResult[AttributeInvT], ...]):
@property
def types(self):
return tuple(r.type for r in self)
Expand All @@ -411,7 +411,7 @@ class OptResultDef(VarResultDef, OptionalDef):
"""An IRDL optional result definition."""


OptOpResult: TypeAlias = OpResult | None
OptOpResult: TypeAlias = OpResult[AttributeInvT] | None


@dataclass(init=True)
Expand Down Expand Up @@ -596,42 +596,46 @@ class _SuccessorFieldDef(_OpDefField[SuccessorDef]):


def result_def(
constraint: IRDLAttrConstraint = Attribute,
constraint: IRDLGenericAttrConstraint[AttributeInvT] = Attribute,
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> OpResult:
) -> OpResult[AttributeInvT]:
"""
Defines a result of an operation.
"""
return cast(OpResult, _ResultFieldDef(ResultDef, constraint))
return cast(OpResult[AttributeInvT], _ResultFieldDef(ResultDef, constraint))


def var_result_def(
constraint: RangeConstraint | IRDLAttrConstraint = Attribute,
constraint: (
GenericRangeConstraint[AttributeInvT] | IRDLGenericAttrConstraint[AttributeInvT]
) = Attribute,
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> VarOpResult:
) -> VarOpResult[AttributeInvT]:
"""
Defines a variadic result of an operation.
"""
return cast(VarOpResult, _ResultFieldDef(VarResultDef, constraint))
return cast(VarOpResult[AttributeInvT], _ResultFieldDef(VarResultDef, constraint))


def opt_result_def(
constraint: RangeConstraint | IRDLAttrConstraint = Attribute,
constraint: (
GenericRangeConstraint[AttributeInvT] | IRDLGenericAttrConstraint[AttributeInvT]
) = Attribute,
*,
default: None = None,
resolver: None = None,
init: Literal[False] = False,
) -> OptOpResult:
) -> OptOpResult[AttributeInvT]:
"""
Defines an optional result of an operation.
"""
return cast(OptOpResult, _ResultFieldDef(OptResultDef, constraint))
return cast(OptOpResult[AttributeInvT], _ResultFieldDef(OptResultDef, constraint))


def prop_def(
Expand Down
Loading

0 comments on commit 4d4bb1e

Please sign in to comment.