Skip to content

Commit

Permalink
feat: added experimental support for linux (#74)
Browse files Browse the repository at this point in the history
* feat: added experimental support for linux
* include cstring for std::memcpy
* warnings about unsupported targets
* allow users to supply custom libmlx archive for unsupported target
* updated README.md and ignore changes in .md files for ci
* skip non-finite to u8 conversion on x86_64 linux

Signed-off-by: Cocoa <[email protected]>
  • Loading branch information
cocoa-xu authored Feb 13, 2025
1 parent 9e318c9 commit 6afc1a7
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 13 deletions.
59 changes: 59 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,68 @@ on:
push:
branches:
- main
paths-ignore:
- '*.md'
- '**/*.md'
pull_request:
paths-ignore:
- '*.md'
- '**/*.md'

jobs:
linux:
name: ${{ matrix.job.arch }}-linux-gnu (${{ matrix.job.elixir }}, ${{ matrix.job.otp }})
runs-on: ${{ matrix.job.runs-os }}
strategy:
fail-fast: false
matrix:
job:
- { arch: "x86_64", runs-os: ubuntu-latest, otp: "25.3.2.15", elixir: "1.15.4" }
- { arch: "aarch64", runs-os: ubuntu-24.04-arm, otp: "25.3.2.15", elixir: "1.15.4" }
env:
MIX_ENV: test
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup Elixir and Erlang
id: setup
run: |
curl -fsSO https://elixir-lang.org/install.sh
sh install.sh elixir@${{ matrix.job.elixir }} otp@${{ matrix.job.otp }}
OTP_VERSION="${{ matrix.job.otp }}"
OTP_MAJOR="${OTP_VERSION%%.*}"
export OTP_PATH=$HOME/.elixir-install/installs/otp/${OTP_VERSION}/bin
export ELIXIR_PATH=$HOME/.elixir-install/installs/elixir/${{ matrix.job.elixir }}-otp-${OTP_MAJOR}/bin
echo "path=${OTP_PATH}:${ELIXIR_PATH}" >> $GITHUB_OUTPUT
echo "${OTP_PATH}" >> $GITHUB_PATH
echo "${ELIXIR_PATH}" >> $GITHUB_PATH
- name: Compile and check warnings
run: |
export PATH="${{ steps.setup.outputs.path }}:${PATH}"
mix local.hex --force
mix local.rebar --force
mix deps.get
mix compile --warnings-as-errors
- name: Run epmd for distributed tests
run: |
export PATH="${{ steps.setup.outputs.path }}:${PATH}"
epmd -daemon
- name: Run tests
run: |
export PATH="${{ steps.setup.outputs.path }}:${PATH}"
if [ "${{ matrix.job.build }}" = "true" ]; then
export LIBMLX_BUILD=true
fi
mix test --warnings-as-errors
macos:
name: macOS (${{ matrix.job.elixir }}, ${{ matrix.job.otp }})
runs-on: macos-14
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

EMLX is the Nx Backend for the [MLX](https://github.com/ml-explore/mlx) library.

Because of MLX's nature, EMLX is only supported on macOS.
Because of MLX's nature, EMLX with GPU backend is only supported on macOS.

MLX with CPU backend is available on most mainstream platforms, however, the CPU backend may not be as optimized as the GPU backend,
especially for non-macOS OSes, as they're not prioritized for development. Right now, EMLX supports x86_64 and arm64 architectures
on both macOS and Linux.

The M-Series Macs have an unified memory architecture, which allows for more passing data between the CPU and GPU to be effectively a no-op.

Expand Down
1 change: 1 addition & 0 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <map>
#include <numeric>
#include <string>
#include <cstring>

using namespace mlx::core;

Expand Down
161 changes: 151 additions & 10 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ defmodule EMLX.MixProject do
use Mix.Project

@app :emlx
@version "0.1.1-dev"
@version "0.1.2"
@mlx_version "0.22.1"

require Logger

def project do
libmlx_config = libmlx_config()

Expand Down Expand Up @@ -74,7 +76,115 @@ defmodule EMLX.MixProject do
end
end

defp libmlx_config() do
defp current_target_from_env do
arch = System.get_env("TARGET_ARCH")
os = System.get_env("TARGET_OS")
abi = System.get_env("TARGET_ABI")

if !Enum.all?([arch, os, abi], &Kernel.is_nil/1) do
"#{arch}-#{os}-#{abi}"
end
end

defp current_target! do
case current_target() do
{:ok, target} ->
target

{:error, reason} ->
Mix.raise(reason)
end
end

defp current_target do
current_target_from_env = current_target_from_env()

if current_target_from_env do
# overwrite current target triplet from environment variables
{:ok, current_target_from_env}
else
current_target(:os.type())
end
end

defp current_target({:win32, _}) do
processor_architecture =
String.downcase(String.trim(System.get_env("PROCESSOR_ARCHITECTURE")))

# https://docs.microsoft.com/en-gb/windows/win32/winprog64/wow64-implementation-details?redirectedfrom=MSDN
partial_triplet =
case processor_architecture do
"amd64" ->
"x86_64-windows-"

"ia64" ->
"ia64-windows-"

"arm64" ->
"aarch64-windows-"

"x86" ->
"x86-windows-"
end

{compiler, _} = :erlang.system_info(:c_compiler_used)

case compiler do
:msc ->
{:ok, partial_triplet <> "msvc"}

:gnuc ->
{:ok, partial_triplet <> "gnu"}

other ->
{:ok, partial_triplet <> Atom.to_string(other)}
end
end

defp current_target({:unix, _}) do
# get current target triplet from `:erlang.system_info/1`
system_architecture = to_string(:erlang.system_info(:system_architecture))
current = String.split(system_architecture, "-", trim: true)

case length(current) do
4 ->
{:ok, "#{Enum.at(current, 0)}-#{Enum.at(current, 2)}-#{Enum.at(current, 3)}"}

3 ->
case :os.type() do
{:unix, :darwin} ->
current =
if "aarch64" == Enum.at(current, 0) do
["arm64" | tl(current)]
else
current
end

# could be something like aarch64-apple-darwin21.0.0
# but we don't really need the last 21.0.0 part
if String.match?(Enum.at(current, 2), ~r/^darwin.*/) do
{:ok, "#{Enum.at(current, 0)}-#{Enum.at(current, 1)}-darwin"}
else
{:ok, system_architecture}
end

_ ->
{:ok, system_architecture}
end

_ ->
{:error, "Cannot determine current target"}
end
end

@supported_targets [
"x86_64-apple-darwin",
"arm64-apple-darwin",
"x86_64-linux-gnu",
"aarch64-linux-gnu",
"riscv64-linux-gnu"
]
defp libmlx_config do
version = System.get_env("LIBMLX_VERSION", @mlx_version)

features = %{
Expand All @@ -85,16 +195,42 @@ defmodule EMLX.MixProject do

variant = to_variant(features)

current_target = current_target!()

cache_dir =
if dir = System.get_env("LIBMLX_CACHE") do
Path.expand(dir)
else
:filename.basedir(:user_cache, "libmlx")
end

libmlx_archive =
Path.join(
cache_dir,
"libmlx-#{version}-#{current_target}#{variant}.tar.gz"
)

libmlx_archive = System.get_env("MLX_ARCHIVE_PATH", libmlx_archive)

features =
if not Enum.member?(@supported_targets, current_target) and
is_nil(System.get_env("MLX_ARCHIVE_PATH")) do
Logger.warning("""
Current target #{current_target} is not officially supported by EMLX, will fallback to building from source.
A prebuilt libmlx archive for this target can be specified by setting the environment variable MLX_ARCHIVE_PATH to the path of the archive.
""")

%{features | build?: true}
else
features
end

%{
target: current_target,
libmlx_archive: libmlx_archive,
version: version,
dir: Path.join(cache_dir, "libmlx-#{version}#{variant}"),
dir: Path.join(cache_dir, "libmlx-#{version}-#{current_target}#{variant}"),
features: features,
variant: variant,
cache_dir: cache_dir
Expand Down Expand Up @@ -141,13 +277,10 @@ defmodule EMLX.MixProject do
defp download_and_unarchive(cache_dir, libmlx_config) do
File.mkdir_p!(cache_dir)

libmlx_archive =
Path.join(cache_dir, "libmlx-#{libmlx_config.version}#{libmlx_config.variant}.tar.gz")

libmlx_archive = System.get_env("MLX_ARCHIVE_PATH", libmlx_archive)
libmlx_archive = libmlx_config.libmlx_archive

url =
"https://github.com/cocoa-xu/mlx-build/releases/download/v#{libmlx_config.version}/mlx-arm64-apple-darwin#{libmlx_config.variant}.tar.gz"
"https://github.com/cocoa-xu/mlx-build/releases/download/v#{libmlx_config.version}/mlx-#{libmlx_config.target}#{libmlx_config.variant}.tar.gz"

sha256_url = "#{url}.sha256"

Expand All @@ -164,10 +297,18 @@ defmodule EMLX.MixProject do
unless File.exists?(libmlx_archive) do
# Download libmlx

if {:unix, :darwin} != :os.type() do
Mix.raise("EMLX only supports macOS for now")
case :os.type() do
{:unix, :darwin} ->
:ok

{:unix, _} ->
Logger.warning("MLX only has CPU backend available for current target")

_ ->
Mix.raise("EMLX only supports macOS and x86_64, aarch64 and riscv64 Linux for now")
end

Mix.shell().info("Downloading libmlx from #{url}")
download!(url, libmlx_archive)
:ok = maybe_verify_integrity!(verify_integrity, libmlx_archive)
end
Expand Down
19 changes: 18 additions & 1 deletion test/emlx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,22 @@ defmodule EMLX.Nx.DoctestTest do
:ok
end

@os_specific_rounding_error (case(:os.type()) do
{:unix, :darwin} ->
[]

{:unix, _} ->
[
# x86_64 and aarch64
atanh: 1,
# aarch64
ifft: 2
]

_ ->
[]
end)

@rounding_error [
exp: 1,
erf: 1,
Expand Down Expand Up @@ -56,5 +72,6 @@ defmodule EMLX.Nx.DoctestTest do
sort: 2
]

doctest Nx, except: @rounding_error ++ @not_supported ++ @to_be_fixed
doctest Nx,
except: @rounding_error ++ @os_specific_rounding_error ++ @not_supported ++ @to_be_fixed
end
12 changes: 11 additions & 1 deletion test/emlx/nx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,17 @@ defmodule EMLX.NxTest do
non_finite =
Nx.stack([Nx.Constants.infinity(), Nx.Constants.nan(), Nx.Constants.neg_infinity()])

for to_type <- [u: 8, s: 16, s: 32] do
skip_infinity_to_u8 =
case :os.type() do
{:unix, subtype} when subtype != :darwin ->
[arch | _] = String.split(to_string(:erlang.system_info(:system_architecture)), "-")
arch == "x86_64"

_ ->
false
end

for to_type <- [u: 8, s: 16, s: 32], !(to_type == {:u, 8} and skip_infinity_to_u8) do
actual = Nx.as_type(non_finite, to_type) |> Nx.backend_transfer()
expected = Nx.backend_copy(non_finite, Nx.BinaryBackend) |> Nx.as_type(to_type)
assert actual == expected
Expand Down

0 comments on commit 6afc1a7

Please sign in to comment.