Skip to content

Commit

Permalink
feat: support binary backend in compiled mode
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Feb 14, 2025
1 parent 35a4dac commit bd60991
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
41 changes: 27 additions & 14 deletions lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -318,25 +318,32 @@ defmodule EMLX do

@impl Nx.Defn.Compiler
def __jit__(key, vars, fun, args_list, opts) do
# TODO: instead of checking the backend here,
# we should automatically convert from binary backend to EMLX backend
# given a device optionsg
case Nx.default_backend() do
EMLX.Backend ->
:ok

{EMLX.Backend, _} ->
:ok

other ->
raise ArgumentError, "EMLX can only be used with the EMLX backend, got: #{inspect(other)}"
end

__compile__(key, vars, fun, opts).(args_list)
end

@impl Nx.Defn.Compiler
def __compile__(key, vars, fun, opts) do
backend = Nx.default_backend()

target_backend =
case backend do
EMLX.Backend ->
backend

{EMLX.Backend, _} ->
backend

Nx.BinaryBackend ->
EMLX.Backend

{Nx.BinaryBackend, _} ->
EMLX.Backend

other ->
raise ArgumentError,
"EMLX can only be used with the EMLX.Backend or Nx.BinaryBackend, got: #{inspect(other)}"
end

expr = fun.(vars)

fn [args] ->
Expand All @@ -346,6 +353,12 @@ defmodule EMLX do
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} ->
{device, ref}

%Nx.Tensor{data: %Nx.BinaryBackend{}} = t ->
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} =
Nx.backend_copy(t, target_backend)

{device, ref}

other ->
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} = Nx.to_tensor(other)
{device, ref}
Expand Down
26 changes: 26 additions & 0 deletions test/emlx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,30 @@ defmodule EMLXTest do
assert_equal(left, Nx.tensor(3))
assert_equal(right, Nx.tensor(-1))
end

test "__jit__ supports binary backend in arguments" do
{left, right} =
Nx.Defn.jit_apply(
&{Nx.add(&1, &2), Nx.subtract(&1, &2)},
[Nx.tensor(1, backend: Nx.BinaryBackend), 2],
compiler: EMLX
)

assert_equal(left, Nx.tensor(3))
assert_equal(right, Nx.tensor(-1))
end

test "__jit__ supports binary backend as the default backend" do
Nx.with_default_backend(Nx.BinaryBackend, fn ->
{left, right} =
Nx.Defn.jit_apply(
&{Nx.add(&1, &2), Nx.subtract(&1, &2)},
[Nx.tensor(1), 2],
compiler: EMLX
)

assert_equal(left, Nx.tensor(3))
assert_equal(right, Nx.tensor(-1))
end)
end
end

0 comments on commit bd60991

Please sign in to comment.