Skip to content

Commit

Permalink
backend: (riscv) use infinite register helper in more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Mar 3, 2025
1 parent c9493b7 commit 2e0e574
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 22 deletions.
15 changes: 7 additions & 8 deletions tests/backend/riscv/test_register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def test_default_reserved_registers():

unallocated = riscv.Registers.UNALLOCATED_INT

def j(index: int):
return riscv.IntRegisterType(f"j_{index}")
j = riscv.IntRegisterType.infinite_register

assert register_queue.pop(riscv.IntRegisterType) == j(0)

Expand Down Expand Up @@ -116,18 +115,18 @@ def get_register_constraints(self) -> RegisterConstraints:
# All new registers. The result register is reused by the allocator for the operand.
op0 = MyInstructionOp.get("", "", "", "")
register_allocator.process_riscv_op(op0)
assert op0.rs0.type == riscv.IntRegisterType("j_1")
assert op0.rs1.type == riscv.IntRegisterType("j_0")
assert op0.rd0.type == riscv.IntRegisterType("j_1")
assert op0.rd1.type == riscv.IntRegisterType("j_0")
assert op0.rs0.type == riscv.IntRegisterType.infinite_register(1)
assert op0.rs1.type == riscv.IntRegisterType.infinite_register(0)
assert op0.rd0.type == riscv.IntRegisterType.infinite_register(1)
assert op0.rd1.type == riscv.IntRegisterType.infinite_register(0)

# One register reserved for inout parameter, the allocator should allocate the output
# to the same register.
op1 = MyInstructionOp.get("", "", "", "a0")
register_allocator.process_riscv_op(op1)
assert op1.rs0.type == riscv.IntRegisterType("j_2")
assert op1.rs0.type == riscv.IntRegisterType.infinite_register(2)
assert op1.rs1.type == riscv.IntRegisterType("a0")
assert op1.rd0.type == riscv.IntRegisterType("j_2")
assert op1.rd0.type == riscv.IntRegisterType.infinite_register(2)
assert op1.rd1.type == riscv.IntRegisterType("a0")


Expand Down
51 changes: 38 additions & 13 deletions tests/backend/riscv/test_register_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ def test_default_reserved_registers():
def test_push_j_register():
register_queue = RiscvRegisterQueue()

register_queue.push(riscv.IntRegisterType("j_0"))
assert riscv.IntRegisterType("j_0") == register_queue.available_int_registers[-1]
register_queue.push(riscv.IntRegisterType.infinite_register(0))
assert (
riscv.IntRegisterType.infinite_register(0)
== register_queue.available_int_registers[-1]
)

register_queue.push(riscv.FloatRegisterType("j_0"))
register_queue.push(riscv.FloatRegisterType.infinite_register(0))
assert (
riscv.FloatRegisterType("j_0") == register_queue.available_float_registers[-1]
riscv.FloatRegisterType.infinite_register(0)
== register_queue.available_float_registers[-1]
)


Expand All @@ -47,18 +51,39 @@ def test_push_register():
def test_reserve_register():
register_queue = RiscvRegisterQueue()

register_queue.reserve_register(riscv.IntRegisterType("j_0"))
assert register_queue.reserved_int_registers[riscv.IntRegisterType("j_0")] == 1
register_queue.reserve_register(riscv.IntRegisterType.infinite_register(0))
assert (
register_queue.reserved_int_registers[
riscv.IntRegisterType.infinite_register(0)
]
== 1
)

register_queue.reserve_register(riscv.IntRegisterType("j_0"))
assert register_queue.reserved_int_registers[riscv.IntRegisterType("j_0")] == 2
register_queue.reserve_register(riscv.IntRegisterType.infinite_register(0))
assert (
register_queue.reserved_int_registers[
riscv.IntRegisterType.infinite_register(0)
]
== 2
)

register_queue.unreserve_register(riscv.IntRegisterType("j_0"))
assert register_queue.reserved_int_registers[riscv.IntRegisterType("j_0")] == 1
register_queue.unreserve_register(riscv.IntRegisterType.infinite_register(0))
assert (
register_queue.reserved_int_registers[
riscv.IntRegisterType.infinite_register(0)
]
== 1
)

register_queue.unreserve_register(riscv.IntRegisterType("j_0"))
assert riscv.IntRegisterType("j_0") not in register_queue.reserved_int_registers
assert riscv.IntRegisterType("j_0") not in register_queue.available_int_registers
register_queue.unreserve_register(riscv.IntRegisterType.infinite_register(0))
assert (
riscv.IntRegisterType.infinite_register(0)
not in register_queue.reserved_int_registers
)
assert (
riscv.IntRegisterType.infinite_register(0)
not in register_queue.available_int_registers
)

# Check assertion error when reserving an available register
reg = register_queue.pop(riscv.IntRegisterType)
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_add_op():
assert a2.type.index == IntAttr(12)

# Registers that aren't predefined should not have an index.
assert isinstance(riscv.IntRegisterType("j_1").index, NoneAttr)
assert isinstance(riscv.IntRegisterType.infinite_register(1).index, NoneAttr)


def test_csr_op():
Expand Down

0 comments on commit 2e0e574

Please sign in to comment.