Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create blocked Jacobi method for eigen decomposition #1510

Merged
merged 18 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1240,25 +1240,6 @@ defmodule Nx.BinaryBackend do
output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
end

@impl true
def eigh(
      {%{type: output_type} = eigenvals_holder, eigenvecs_holder},
      %{type: input_type, shape: input_shape} = tensor,
      opts
    ) do
  # Batched eigendecomposition: the input binary is walked in chunks of n * n
  # elements and each chunk is handed to B.Matrix.eigh/5 as an n x n matrix.
  bin = to_binary(tensor)

  # The last axis of the input shape is used as the side length n.
  n = elem(input_shape, tuple_size(input_shape) - 1)

  # Appends the result of one matrix to the binaries accumulated so far.
  decompose_one = fn matrix, {vals_so_far, vecs_so_far} ->
    {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
    {vals_so_far <> vals, vecs_so_far <> vecs}
  end

  {eigenvals, eigenvecs} =
    bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, decompose_one)

  {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
end

@impl true
def lu(
{%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
Expand Down
272 changes: 0 additions & 272 deletions nx/lib/nx/binary_backend/matrix.ex
Original file line number Diff line number Diff line change
Expand Up @@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do

defp do_ts([], [], _idx, acc), do: acc

# Degenerate sizes (n of 0 or 1): the input is returned untouched as R and
# Q is the 1 x 1 matrix [[1.0]].
defp qr_decomposition(matrix, n, _eps) when n in 0..1, do: {[[1.0]], matrix}

# QR decomposition of a square n x n matrix through Householder transforms.
#
# This used to handle generic shapes, but eigh is its only remaining caller,
# so it only supports square matrices. Returns {q, r}, both passed through
# approximate_zeros/2 with the given eps.
defp qr_decomposition(matrix, n, eps) when n >= 2 do
  {q_final, r_final} =
    Enum.reduce(0..(n - 2)//1, {nil, matrix}, fn step, {q_acc, r_acc} ->
      # Reflector built from the part of column `step` starting at row `step`.
      reflector =
        r_acc
        |> slice_matrix([step, step], [n - step, 1])
        |> householder_reflector(n, eps)

      # Q starts out unallocated (nil); the first reflector becomes Q itself.
      # TODO: Resolve the inconsistency with the Householder reflector.
      # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
      next_q =
        case q_acc do
          nil -> reflector
          _ -> dot_matrix_real(q_acc, reflector)
        end

      {next_q, dot_matrix_real(reflector, r_acc)}
    end)

  {approximate_zeros(q_final, eps), approximate_zeros(r_final, eps)}
end

# Raises the ArgumentError used to reject non-Hermitian input.
defp raise_not_hermitian do
  message = "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
  raise ArgumentError, message: message
end

# Eigendecomposition of a single Hermitian n x n matrix stored as a binary.
#
# Options read from `opts`:
#   * `:eps` - tolerance used for the Hermitian check, for the convergence
#     test and for flushing small entries to zero
#   * `:max_iter` - upper bound on the number of QR iterations
#
# Returns `{eigenvalues_binary, eigenvectors_binary}`, both encoded with
# `output_type`. Raises ArgumentError when the input is not Hermitian
# within `eps`.
def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do
  eps = opts[:eps]
  max_iter = opts[:max_iter]

  # Validate that the input is a Hermitian matrix using the relation A^* = A.
  a = binary_to_matrix(input_data, input_type, input_shape)

  is_hermitian =
    a
    |> transpose_matrix()
    |> Enum.map(fn a_row -> Enum.map(a_row, &Complex.conjugate/1) end)
    |> is_approximately_same?(a, eps)

  if not is_hermitian do
    raise_not_hermitian()
  end

  # Hessenberg decomposition
  {h, q_h} = hessenberg_decomposition(a, n, eps)

  # QR iteration for eigenvalues and eigenvectors. The loop stops early once
  # the accumulated Q no longer changes by more than eps between iterations.
  {eigenvals_diag, eigenvecs} =
    Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} ->
      # QR decomposition
      {q_now, r_now} = qr_decomposition(a_old, n, eps)

      # Update matrix A, Q
      a_new = dot_matrix_real(r_now, q_now)
      q_new = dot_matrix_real(q_old, q_now)

      if is_approximately_same?(q_old, q_new, eps) do
        {:halt, {a_new, q_new}}
      else
        {:cont, {a_new, q_new}}
      end
    end)

  # Obtain the eigenvalues, which are the diagonal elements. The explicit
  # `//1` step keeps the range empty for n == 0 instead of counting down
  # from 0 to -1.
  indices_diag = for idx <- 0..(n - 1)//1, do: [idx, idx]
  eigenvals = get_matrix_elements(eigenvals_diag, indices_diag)

  # In general, the eigenvalues of a Hermitian matrix are real numbers
  eigenvals_real = Enum.map(eigenvals, &Complex.real/1)

  # Reduce the elements smaller than eps to zero
  {eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type),
   eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)}
end

# Degenerate sizes (n of 0 or 1): the matrix is returned as-is, paired with
# the 1 x 1 matrix [[1.0]] as Q.
defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1, do: {matrix, [[1.0]]}

# Hessenberg decomposition performed with Householder transforms.
#
# At every step a reflector H is built from the part of the current column
# below the diagonal; the working matrix becomes H * M * adjoint(H) and Q
# accumulates the product of the reflectors. Returns {hessenberg, q}, both
# passed through approximate_zeros/2 with the given eps.
defp hessenberg_decomposition(matrix, n, eps) do
  {hess_final, q_final} =
    Enum.reduce(0..(n - 2)//1, {matrix, nil}, fn step, {hess_acc, q_acc} ->
      reflector =
        hess_acc
        |> slice_matrix([step + 1, step], [n - step - 1, 1])
        |> householder_reflector(n, eps)

      # Q starts out unallocated (nil); the first reflector becomes Q itself.
      # TODO: Resolve the inconsistency with the Householder reflector.
      # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
      next_q =
        case q_acc do
          nil -> reflector
          _ -> dot_matrix_real(q_acc, reflector)
        end

      # Apply the reflector on the left and its adjoint on the right.
      reflector_adj = adjoint_matrix(reflector)

      next_hess =
        reflector
        |> dot_matrix_real(hess_acc)
        |> dot_matrix_real(reflector_adj)

      {next_hess, next_q}
    end)

  {approximate_zeros(hess_final, eps), approximate_zeros(q_final, eps)}
end

# Returns true when every pair of corresponding entries of the matrices `a`
# and `b` (lists of rows) differs by at most `eps` in absolute value. A
# difference whose absolute value is :nan is also accepted as equal.
defp is_approximately_same?(a, b, eps) do
  entries_close? = fn {a_elem, b_elem} ->
    distance = Complex.abs(a_elem - b_elem)
    distance == :nan or distance <= eps
  end

  rows_close? = fn {a_row, b_row} ->
    a_row |> Enum.zip(b_row) |> Enum.all?(entries_close?)
  end

  a |> Enum.zip(b) |> Enum.all?(rows_close?)
end

def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
a = binary_to_matrix(input_data, input_type, input_shape)
eps = opts[:eps]
Expand Down Expand Up @@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
end)
end

## Householder helpers

defp householder_reflector(a, target_k, eps)

# Nothing to reflect: produce the target_k x target_k identity as a list of rows.
defp householder_reflector([], target_k, _eps) do
  indices = 0..(target_k - 1)

  identity_entries =
    for i <- indices, j <- indices do
      if i == j, do: 1, else: 0
    end

  Enum.chunk_every(identity_entries, target_k)
end

# Builds the target_k x target_k reflector I - scale * outer(v, v) for the
# column `a`, where `v` and `scale` come from householder_reflector_pivot/2.
# The result is a list of rows.
defp householder_reflector(a, target_k, eps) do
  {pivot_v, scale, is_complex} = householder_reflector_pivot(a, eps)

  # Left-pad v with zeros so that it has target_k entries; rows and columns
  # that fall inside the padding keep their identity values.
  prefix_threshold = target_k - length(pivot_v)
  v = List.duplicate(0, prefix_threshold) ++ pivot_v

  # In the complex case the second factor of every outer-product entry is
  # conjugated.
  second_factors = if is_complex, do: Enum.map(v, &Complex.conjugate/1), else: v

  indexed_first = Enum.with_index(v)
  indexed_second = Enum.with_index(second_factors)

  for {first_factor, i} <- indexed_first do
    for {second_factor, j} <- indexed_second do
      identity_element = if i == j, do: 1, else: 0

      if i >= prefix_threshold and j >= prefix_threshold do
        identity_element -
          scale * first_factor * second_factor
      else
        identity_element
      end
    end
  end
end

# Real case (selected when the first entry is a plain number).
#
# Returns {v, scale, false}, where v is normalised so that its first entry
# is 1 and the reflector is I - scale * outer(v, v). When the squared norm
# of the entries after the first is below eps, scale is 0, which makes the
# resulting reflector the identity.
defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do
  norm_a_squared =
    Enum.reduce(a, 0, fn entry, sum -> entry * Complex.conjugate(entry) + sum end)

  # Squared norm of everything but the first entry.
  norm_a_sq_1on = norm_a_squared - a_0 * a_0

  cond do
    norm_a_sq_1on < eps ->
      {[1 | tail], 0, false}

    true ->
      norm_a = Complex.sqrt(norm_a_squared)

      # Two algebraically equivalent forms of a_0 - norm(a); the second one
      # avoids subtracting two nearly equal values when a_0 is positive.
      v_0 =
        if a_0 <= 0 do
          a_0 - norm_a
        else
          -norm_a_sq_1on / (a_0 + norm_a)
        end

      v_0_sq = v_0 * v_0
      scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq)
      scaled_tail = for entry <- tail, do: entry / v_0

      {[1 | scaled_tail], scale, false}
  end
end

# Complex case.
#
# Returns {v, 2, true}, where v is the unit-norm vector u / norm(u) and
# u = x + alfa * e1 with alfa = exp(i * phase(a_0)) * norm(x).
defp householder_reflector_pivot([a_0 | tail], _eps) do
  tail_norm_sq =
    Enum.reduce(tail, 0, fn entry, sum -> Complex.abs_squared(entry) + sum end)

  norm_a = Complex.sqrt(tail_norm_sq + Complex.abs_squared(a_0))
  alfa = Complex.exp(Complex.new(0, Complex.phase(a_0))) * norm_a

  u_0 = a_0 + alfa
  norm_u = Complex.sqrt(tail_norm_sq + Complex.abs_squared(u_0))

  v = for entry <- [u_0 | tail], do: entry / norm_u
  {v, 2, true}
end

## Matrix (2-D array) manipulation

defp dot_matrix([], _), do: 0
Expand All @@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
end)
end

# Plain matrix product of `m1` and `m2`, both given as lists of rows.
#
# Each output entry is the sum of element-wise products of a row of `m1`
# and a column of `m2`; no conjugation is applied. An empty `m1` yields an
# empty result without inspecting `m2`.
defp dot_matrix_real([], _m2), do: []

defp dot_matrix_real(m1, m2) do
  # The columns of m2 do not depend on the current row of m1, so transpose
  # once up front instead of once per row.
  columns = transpose_matrix(m2)

  Enum.map(m1, fn row ->
    Enum.map(columns, fn col ->
      Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end)
    end)
  end)
end

# Flat list input: each entry is conjugated and wrapped in its own
# single-element row.
defp adjoint_matrix([x | _] = m) when not is_list(x) do
  for entry <- m, do: [Complex.conjugate(entry)]
end

# List-of-rows input: conjugate transpose, i.e. every column of `m` becomes
# a row of conjugated entries.
defp adjoint_matrix(m) do
  Enum.zip_with(m, fn column ->
    for entry <- column, do: Complex.conjugate(entry)
  end)
end

defp transpose_matrix([x | _] = m) when not is_list(x) do
Enum.map(m, &[&1])
end
Expand Down
Loading
Loading