Skip to content

Commit

Permalink
More cleaning up, but that belongs to the PR.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed May 16, 2024
1 parent 56a309e commit e0d5a52
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 53 deletions.
60 changes: 15 additions & 45 deletions src/jace/translator/jaxpr_translator_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import itertools
from collections.abc import Iterable, Mapping, MutableSequence, Sequence
from typing import Any, Final, cast, overload, Literal
from typing import Any, Final, Literal, cast, overload

import dace
import jax
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
Notes:
`sub_translators` is not copied, thus the user has to guarantee,
that it will not change during translation.
It is highly advised but not requiered to use the output of
It is highly advised but not required to use the output of
`get_subtranslators()` or pass a copy as argument.
"""

Expand Down Expand Up @@ -105,7 +105,6 @@ def translate_jaxpr(
inp_scalar_as_array: bool = False,
name: str | None = None,
reserved_names: str | Iterable[str] = (),
allow_empty_jaxpr: bool = False,
) -> translator.TranslatedJaxprSDFG:
"""Perform the translation of a Jaxpr into a SDFG.
Expand All @@ -121,19 +120,9 @@ def translate_jaxpr(
inp_scalar_as_array: Translate scalar _input_ arguments to arrays of length 1.
name: Use this name for the SDFG instead some generated one.
reserved_names: Prevent the generation of variables with these names, see `self.add_array()` for more.
allow_empty_jaxpr: Allows empty Jaxpr.
Notes:
Every time this function is called a new revision index is generated.
"""
if (len(jaxpr.eqns) == 0) and (not allow_empty_jaxpr):
raise ValueError("Passed an empty Jaxpr, but did not allow for empty Jaxpr.")
if not isinstance(jaxpr, jax_core.ClosedJaxpr):
raise TypeError(f"Expected a 'jax.core.ClosedJaxp' instance but got '{type(jaxpr)}'")
if len(jaxpr.effects) != 0:
raise NotImplementedError("'Jaxpr' with side effects are not supported.")
if len(jaxpr.out_avals) == 0:
raise ValueError("Jaxpr has zero output variables.")
if not jax.config.read("jax_enable_x64"):
raise NotImplementedError("The translation only works if 'jax_enable_x64' is enabled.")

Expand Down Expand Up @@ -164,7 +153,6 @@ def append_new_state(
label: str | None = None,
condition: dprop.CodeBlock | None = None,
assignments: Mapping[str, Any] | None = None,
*,
prev_state: dace.SDFGState | None = None,
) -> dace.SDFGState:
"""Creates a new `SDFGState` and adds it to the SDFG.
Expand Down Expand Up @@ -334,7 +322,7 @@ def add_jax_name_mapping(
jax_var: The Jax variable.
sdfg_name: The name of the corresponding SDFG variable.
"""
assert isinstance(sdfg_name, str) and (len(sdfg_name) > 0) # noqa: PT018 # Should be one assertion.
assert len(sdfg_name) > 0

if jax_var in self._ctx.jax_name_map:
if self._ctx.jax_name_map[jax_var] == sdfg_name: # noops.
Expand All @@ -361,10 +349,6 @@ def add_reserved_names(
return self
if isinstance(reserved_names, str):
reserved_names = [reserved_names]
elif isinstance(reserved_names, Iterable):
pass
else:
raise TypeError(f"Does not know how to handle the type '{type(reserved_names)}'.")
self._reserved_names.update(reserved_names)
return self

Expand Down Expand Up @@ -423,9 +407,7 @@ def add_array(
If you need to create a special array, you can use `jace.util.JaCeVar`
to create a pseudo Jax variable.
"""
assert self.is_allocated()

shape: Sequence[int] = util.get_jax_var_shape(arg)
shape: tuple[int] = util.get_jax_var_shape(arg)
dtype = util.get_jax_var_dtype(arg)
offset = None # i.e. no offset
storage: dace.StorageType = dace.StorageType.Default # Set at later stages (optimization)
Expand All @@ -452,7 +434,6 @@ def add_array(
find_new_name = False
alt_name = util._propose_jax_name(arg, self._ctx.jax_name_map)
if alt_name is not None:
assert isinstance(alt_name, str)
find_new_name = False # If a name was given, then use it no matter what.
if len(alt_name) == 0:
raise ValueError("Passed an empty 'alt_name'.")
Expand All @@ -468,18 +449,17 @@ def add_array(
raise ValueError(
f"Specified 'name_prefix' ('{name_prefix}') but passed '{alt_name}' as 'alt_name'."
)
if name_prefix is not None:
assert isinstance(name_prefix, str)
if len(name_prefix) == 0:
raise ValueError("Specified an empty 'name_prefix'.")
if (name_prefix is not None) and (len(name_prefix) == 0):
raise ValueError("Specified an empty 'name_prefix'.")

# Checking the strides.
if strides is not None:
if is_scalar:
raise ValueError("Specified a stride for a scalar.")
if isinstance(strides, (str, dace.symbol, int)):
strides = (strides,)
assert isinstance(strides, tuple)
elif not isinstance(strides, tuple):
strides = tuple(strides)
if len(strides) != len(shape):
raise ValueError(
f"'strides' has length {len(strides)}, but array rank is {len(shape)}."
Expand All @@ -499,8 +479,6 @@ def add_array(
raise NotImplementedError("Jax Literals are not supported.")
if alt_name is None:
raise ValueError(f"Passed literal '{arg}', but not specified a name to use.")
else:
raise TypeError(f"Does not know how to handle '{type(arg).__name__}'.")

if alt_name is None:
# If we are the root translator, then we will use `prop_name` directly;
Expand Down Expand Up @@ -623,15 +601,17 @@ def create_jax_var_list( # type: ignore[misc]
"""
if only_creation and prevent_creation:
raise ValueError("Specified both 'only_creation' and 'prevent_creation'.")
assert "update_var_mapping" not in kwargs
assert (
"update_var_mapping" not in kwargs
), "You can not pass 'update_var_mapping' as argument to 'create_jax_var_list()'."

ret_list: list[None | str] = []
for jax_var in jax_var_list:
if isinstance(jax_var, jax_core.Literal):
if not handle_literals:
raise ValueError("Encountered a literal but `handle_literals` was `False`.")
sdfg_name = None
elif isinstance(jax_var, (jax_core.Var, util.JaCeVar)):
else:
mapped_sdfg_name: str | None = self.map_jax_var_to_sdfg(jax_var, allow_fail=True)
if (mapped_sdfg_name is None) and prevent_creation:
raise ValueError(f"'prevent_creation' given but have to create '{jax_var}'.")
Expand All @@ -643,8 +623,6 @@ def create_jax_var_list( # type: ignore[misc]
sdfg_name = mapped_sdfg_name
# Calling `add_jax_name_mapping` is save, because if the mapping does already exists it is a no ops.
self.add_jax_name_mapping(jax_var, sdfg_name)
else:
raise TypeError(f"Does not know how to handle '{type(jax_var).__name__}'")

ret_list.append(sdfg_name)

Expand All @@ -671,7 +649,6 @@ def _create_initial_input(
raise RuntimeError("Driver is not allocated, can not create constants.")
if len(self._ctx.inp_names) != 0:
raise RuntimeError("Called '_create_initial_input()' twice?")
assert len(self._ctx.out_names) == 0

# Handle the initial input arguments
sdfg: dace.SDFG = self._ctx.sdfg
Expand Down Expand Up @@ -709,7 +686,7 @@ def _create_constants(
if not self.is_allocated():
raise RuntimeError("Driver is not allocated, can not create constants.")
if len(jaxpr.consts) == 0:
return []
return ()

sdfg_const_names: Sequence[str] = self.create_jax_var_list(
jax_var_list=jaxpr.jaxpr.constvars,
Expand Down Expand Up @@ -830,9 +807,7 @@ def _translate_single_eqn(
# Find the subtranslator
prim_name: str = eqn.primitive.name
if prim_name not in self._sub_translators:
raise NotImplementedError(
f"No subtranslators known to handle '{prim_name}' || {type(self._sub_translators)}."
)
raise NotImplementedError(f"No subtranslators known to handle '{prim_name}'.")
subtranslator = self._sub_translators[prim_name]

# Create the state into which the equation should be translated
Expand All @@ -856,11 +831,6 @@ def _translate_single_eqn(
if eqn_state is not self._ctx.terminal_state:
raise RuntimeError("Inconsistent terminal state was detected.")
new_sdfg_term_state = eqn_state
elif isinstance(new_sdfg_term_state, dace.SDFGState):
# TODO(phimuell): use `last_term_state` to test if `new_sdfg_term_state` is reachable.
pass
else:
raise TypeError(f"Encountered illegal types '{type(new_sdfg_term_state)}'")

# In case a subtranslator decided to not use the variables we created for it, which is allowed
# but he must update the `out_var_names` list correctly, we will now verify this.
Expand Down Expand Up @@ -898,7 +868,7 @@ def _translate_jaxpr_internal(
Such variables are included by some transformations such as `grad()`.
"""
nb_translated_eqn: int = 0
out_var_names: Sequence[str] = []
out_var_names: Sequence[str] = ()

# Translate the equations one by one.
for eqn in jaxpr.jaxpr.eqns:
Expand Down
1 change: 0 additions & 1 deletion src/jace/translator/managing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def wrapper(real_fun: Callable) -> PrimitiveTranslator:

return wrapper

assert inspect.isfunction(fun)
if getattr(fun, "primitive", prim_name) != prim_name:
raise ValueError(f"Passed 'fun' already '{fun.primitive}' as 'primitive' property.") # type: ignore[attr-defined]

Expand Down
10 changes: 3 additions & 7 deletions src/jace/util/jax_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,12 @@ def _propose_jax_name(
raise RuntimeError(
f"Can not propose a second name for '{jax_var}', it already known as '{jax_name_map[jax_var]}'."
)
if isinstance(jax_var, jax_core.Var):
pass
elif isinstance(jax_var, JaCeVar):
if isinstance(jax_var, JaCeVar) and (jax_var.name != ""):
# If the name of the JaCe variable is empty, then use the name proposing
# technique used for Jax variables; Mostly used for debugging.
if jax_var.name != "":
return jax_var.name
else:
raise TypeError(f"Can not propose a name for '{jax_var}'")
return jax_var.name

# This code is taken from the Jax source.
c = len(jax_name_map)
jax_name = ""
while len(jax_name) == 0 or c != 0:
Expand Down

0 comments on commit e0d5a52

Please sign in to comment.