Skip to content

Commit

Permalink
Added Enrique's fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Oct 3, 2024
1 parent 29d8452 commit 95df5e2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
14 changes: 7 additions & 7 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from jace import stages, translator, util


__all__ = ["DEFAUL_BACKEND", "JITOptions", "grad", "jacfwd", "jacrev", "jit"]
__all__ = ["DEFAULT_BACKEND", "JITOptions", "grad", "jacfwd", "jacrev", "jit"]

_P = ParamSpec("_P")

DEFAUL_BACKEND: Final[str] = "cpu"
DEFAULT_BACKEND: Final[str] = "cpu"


class JITOptions(TypedDict, total=False):
Expand All @@ -34,11 +34,11 @@ class JITOptions(TypedDict, total=False):
`jace.jit`. Furthermore, some additional ones might be supported.
Args:
backend: Target platform for which DaCe should generate code. Supported values
are `'cpu'` or `'gpu'`.
backend: Target platform for which DaCe should generate code. Supported values
are `'cpu'` or `'gpu'`.
"""

backend: str
backend: Literal["cpu", "gpu"]


@overload
Expand Down Expand Up @@ -87,7 +87,7 @@ def jit(
raise ValueError(
f"The following arguments to 'jace.jit' are not supported: {', '.join(not_supported_jit_keys)}."
)
if kwargs.get("backend", DEFAUL_BACKEND).lower() not in {"cpu", "gpu"}:
if kwargs.get("backend", DEFAULT_BACKEND).lower() not in {"cpu", "gpu"}:
raise ValueError(f"The backend '{kwargs['backend']}' is not supported.")

def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
Expand All @@ -99,7 +99,7 @@ def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
else primitive_translators
),
jit_options=kwargs,
device=util.parse_backend_jit_option(kwargs.get("backend", DEFAUL_BACKEND)),
device=util.to_device_type(kwargs.get("backend", DEFAULT_BACKEND)),
)
functools.update_wrapper(jace_wrapper, f)
return jace_wrapper
Expand Down
4 changes: 2 additions & 2 deletions src/jace/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ def make_jaxpr(

# NOTE: In the current implementation we are using `jax.make_jaxpr()`. But this
# is a different implementation than `jax.jit()` uses. The main difference
# between the two, seems to be the set of arguments that are supported. In JaCe,
# between the two seems to be the set of arguments that are supported. In JaCe,
# however, we want to support all arguments that `jace.jit()` does.
# For establishing compatibility we have to clear the arguments to make them
# compatible, with what `jax.make_jaxpr()` and `jace.jit()` supports.
# compatible with what `jax.make_jaxpr()` and `jace.jit()` supports.
trace_options = {}

def tracer_impl(
Expand Down
4 changes: 2 additions & 2 deletions src/jace/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
get_jax_var_shape,
is_tracing_ongoing,
move_into_jax_array,
parse_backend_jit_option,
propose_jax_name,
to_device_type,
translate_dtype,
)
from .traits import (
Expand Down Expand Up @@ -53,7 +53,7 @@
"is_scalar",
"is_tracing_ongoing",
"move_into_jax_array",
"parse_backend_jit_option",
"propose_jax_name",
"to_device_type",
"translate_dtype",
]
8 changes: 1 addition & 7 deletions src/jace/util/jax_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@
import numpy as np


try:
import cupy as cp # type: ignore[import-not-found]
except ImportError:
cp = None


@dataclasses.dataclass(repr=True, frozen=True, eq=False)
class JaCeVar:
"""
Expand Down Expand Up @@ -266,7 +260,7 @@ def get_jax_literal_value(lit: jax_core.Atom) -> bool | float | int | np.generic
raise TypeError(f"Failed to extract value from '{lit}' ('{val}' type: {type(val).__name__}).")


def parse_backend_jit_option(
def to_device_type(
backend: str | dace.DeviceType,
) -> dace.DeviceType:
"""Turn JAX' `backend` option into the proper DaCe device type."""
Expand Down

0 comments on commit 95df5e2

Please sign in to comment.