diff --git a/examples/cifar10.exs b/examples/cifar10.exs index ca728886..24123c52 100644 --- a/examples/cifar10.exs +++ b/examples/cifar10.exs @@ -1,12 +1,16 @@ Mix.install([ - {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, + {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"}, {:exla, github: "elixir-nx/exla", sparse: "exla"}, {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:scidata, "~> 0.1.1"}, + {:scidata, "~> 0.1.1"} ]) +# Configure default platform with accelerator precedence as tpu > cuda > rocm > host +EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host]) + defmodule Cifar do require Axon + alias Axon.Loop.State defp transform_images({bin, type, shape}) do bin @@ -14,6 +18,7 @@ defmodule Cifar do |> Nx.reshape({elem(shape, 0), 3, 32, 32}) |> Nx.divide(255.0) |> Nx.to_batched_list(32) + |> Enum.split(1500) end defp transform_labels({bin, type, _}) do @@ -22,15 +27,7 @@ defmodule Cifar do |> Nx.new_axis(-1) |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) |> Nx.to_batched_list(32) - end - - defp view_images(images, {start_index, len}) do - images - |> hd() - |> Nx.slice_axis(start_index, len, 0) - |> Nx.mean(axes: [1], keep_axes: true) - |> Nx.to_heatmap() - |> IO.inspect + |> Enum.split(1500) end defp build_model(input_shape) do @@ -39,43 +36,77 @@ defmodule Cifar do |> Axon.batch_norm() |> Axon.max_pool(kernel_size: {2, 2}) |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu) - |> Axon.spatial_dropout() |> Axon.batch_norm() |> Axon.max_pool(kernel_size: {2, 2}) - |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu) - |> Axon.batch_norm() |> Axon.flatten() |> Axon.dense(64, activation: :relu) - |> Axon.dropout() + |> Axon.dropout(rate: 0.5) |> Axon.dense(10, activation: :softmax) end - defp train_model(model, {train_images, train_labels}, epochs) do + defp log_metrics( + %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state, + mode + ) do + loss = + case mode do + :train -> + %{loss: loss} = pstate + "Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}" + + :test -> + "" + end + + metrics = + metrics + |> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end) + |> Enum.join(" ") + + IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}") + + {:continue, state} + end + + defp train_model(model, train_images, train_labels, epochs) do + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam) + |> Axon.Loop.metric(:accuracy, "Accuracy") + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50) + |> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA) + end + + defp test_model(model, model_state, test_images, test_labels) do model - |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01), metrics: [:accuracy]) - |> Axon.Training.train(train_images, train_labels, epochs: epochs, compiler: EXLA) - |> Nx.backend_transfer() + |> Axon.Loop.evaluator(model_state) + |> Axon.Loop.metric(:accuracy, "Accuracy") + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :test), every: 50) + |> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA) end def run do - {train_images, train_labels} = Scidata.CIFAR10.download(transform_images: &transform_images/1, transform_labels: &transform_labels/1) + {images, labels} = + Scidata.CIFAR10.download( + transform_images: &transform_images/1, + transform_labels: &transform_labels/1 + ) + + 
{train_images, test_images} = images + {train_labels, test_labels} = labels - view_images(train_images, {0, 1}) + model = build_model({nil, 3, 32, 32}) |> IO.inspect() - model = build_model({nil, 3, 32, 32}) |> IO.inspect + IO.write("\n\n Training Model \n\n") - final_training_state = + model_state = model - |> train_model({train_images, train_labels}, 20) - |> IO.inspect() + |> train_model(train_images, train_labels, 10) - test_images = train_images |> hd() |> Nx.slice_axis(10, 3, 0) - view_images(train_images, {10, 3}) + IO.write("\n\n Testing Model \n\n") - model - |> Axon.predict(final_training_state[:params], test_images) - |> Nx.argmax(axis: -1) - |> IO.inspect + test_model(model, model_state, test_images, test_labels) + + IO.write("\n\n") end end diff --git a/examples/fashionmnist_autoencoder.exs b/examples/fashionmnist_autoencoder.exs index f3a5c49d..f9967078 100644 --- a/examples/fashionmnist_autoencoder.exs +++ b/examples/fashionmnist_autoencoder.exs @@ -1,12 +1,16 @@ Mix.install([ - {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, + {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"}, {:exla, github: "elixir-nx/exla", sparse: "exla"}, {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}, {:scidata, "~> 0.1.1"}, ]) +# Configure default platform with accelerator precedence as tpu > cuda > rocm > host +EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host]) + defmodule Fashionmist do require Axon + alias Axon.Loop.State defmodule Autoencoder do defp encoder(x, latent_dim) do @@ -36,10 +40,36 @@ defmodule Fashionmist do |> Nx.to_batched_list(32) end + defp log_metrics( + %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state, + mode + ) do + loss = + case mode do + :train -> + %{loss: loss} = pstate + "Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}" + + :test -> + "" + end + + metrics = + metrics + |> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end) + |> Enum.join(" ") + + IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}") + + {:continue, state} + end + defp train_model(model, train_images, epochs) do model - |> Axon.Training.step(:mean_squared_error, Axon.Optimizers.adam(0.01), metrics: [:mean_absolute_error]) - |> Axon.Training.train(train_images, train_images, epochs: epochs, compiler: EXLA) + |> Axon.Loop.trainer(:mean_squared_error, :adam) + |> Axon.Loop.metric(:mean_absolute_error, "Error") + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50) + |> Axon.Loop.run(Stream.zip(train_images, train_images), epochs: epochs, compiler: EXLA) end def run do @@ -47,7 +77,7 @@ defmodule Fashionmist do model = Autoencoder.build_model({nil, 1, 28, 28}, 64) |> IO.inspect - final_training_state = train_model(model, train_images, 5) + model_state = train_model(model, train_images, 5) sample_image = train_images @@ -58,7 +88,7 @@ defmodule Fashionmist do sample_image |> Nx.to_heatmap() |> IO.inspect model - |> Axon.predict(final_training_state[:params], sample_image, compiler: EXLA) + |> Axon.predict(model_state, sample_image, compiler: EXLA) |> Nx.to_heatmap() |> IO.inspect() end diff --git a/examples/mnist.exs b/examples/mnist.exs index bc1e09de..4cf9aa39 100644 --- a/examples/mnist.exs +++ b/examples/mnist.exs @@ -1,19 +1,26 @@ Mix.install([ - {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, - {:exla, github: "elixir-nx/exla", sparse: "exla"}, + {:axon, "~> 0.1.0-dev", github: 
"elixir-nx/axon"}, + {:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"}, {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}, - {:scidata, "~> 0.1.1"}, + {:scidata, "~> 0.1.1"} ]) +# Configure default platform with accelerator precedence as tpu > cuda > rocm > host +EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host]) + defmodule Mnist do require Axon + alias Axon.Loop.State + defp transform_images({bin, type, shape}) do bin |> Nx.from_binary(type) |> Nx.reshape({elem(shape, 0), 784}) |> Nx.divide(255.0) |> Nx.to_batched_list(32) + # Test split + |> Enum.split(1750) end defp transform_labels({bin, type, _}) do @@ -22,15 +29,8 @@ defmodule Mnist do |> Nx.new_axis(-1) |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) |> Nx.to_batched_list(32) - end - - defp view_images(images, {start_index, len}) do - images - |> hd() - |> Nx.slice_axis(start_index, len, 0) - |> Nx.reshape({:auto, 28, 28}) - |> Nx.to_heatmap() - |> IO.inspect + # Test split + |> Enum.split(1750) end defp build_model(input_shape) do @@ -40,30 +40,70 @@ defmodule Mnist do |> Axon.dense(10, activation: :softmax) end - defp train_model(model, {train_images, train_labels}, epochs) do + defp log_metrics( + %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state, + mode + ) do + loss = + case mode do + :train -> + %{loss: loss} = pstate + "Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}" + + :test -> + "" + end + + metrics = + metrics + |> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end) + |> Enum.join(" ") + + IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}") + + {:continue, state} + end + + defp train_model(model, train_images, train_labels, epochs) do model - |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005), metrics: [:accuracy]) - |> Axon.Training.train(train_images, train_labels, epochs: epochs, compiler: EXLA, log_every: 100) + |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) + |> Axon.Loop.metric(:accuracy, "Accuracy") + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50) + |> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA) + end + + defp test_model(model, model_state, test_images, test_labels) do + model + |> Axon.Loop.evaluator(model_state) + |> Axon.Loop.metric(:accuracy, "Accuracy") + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :test), every: 50) + |> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA) end def run do - {train_images, train_labels} = Scidata.MNIST.download(transform_images: &transform_images/1, transform_labels: &transform_labels/1) + {images, labels} = + Scidata.MNIST.download( + transform_images: &transform_images/1, + transform_labels: &transform_labels/1 + ) + + {train_images, test_images} = images + {train_labels, test_labels} = labels - view_images(train_images, {0, 1}) + model = build_model({nil, 784}) |> IO.inspect() - model = build_model({nil, 784}) |> IO.inspect + IO.write("\n\n Training Model \n\n") - final_training_state = + model_state = model - |> train_model({train_images, train_labels}, 10) + |> train_model(train_images, train_labels, 5) - test_images = train_images |> hd() |> Nx.slice_axis(10, 3, 0) - view_images(train_images, {10, 3}) + IO.write("\n\n Testing Model \n\n") model - |> Axon.predict(final_training_state[:params], test_images) - |> Nx.argmax(axis: -1) - |> 
IO.inspect + |> test_model(model_state, test_images, test_labels) + + IO.write("\n\n") end end diff --git a/examples/mnist_gan.exs b/examples/mnist_gan.exs index bb3c2c60..0834e426 100644 --- a/examples/mnist_gan.exs +++ b/examples/mnist_gan.exs @@ -1,172 +1,180 @@ +Mix.install([ + {:axon, github: "elixir-nx/axon"}, + {:exla, github: "elixir-nx/nx", sparse: "exla"}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, + {:scidata, "~> 0.1.0"} +]) + +EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host]) + defmodule MNISTGAN do require Axon - require Axon.Updates + alias Axon.Loop.State import Nx.Defn - @default_defn_compiler {EXLA, keep_on_device: true} + defp transform_images({bin, type, shape}) do + bin + |> Nx.from_binary(type) + |> Nx.reshape({elem(shape, 0), 1, 28, 28}) + |> Nx.divide(255.0) + |> Nx.to_batched_list(32) + end - def generator do - Axon.input({nil, 100}) - |> Axon.dense(256, activation: :leaky_relu) + defp build_generator(z_dim) do + Axon.input({nil, z_dim}) + |> Axon.dense(256) + |> Axon.relu() |> Axon.batch_norm() - |> Axon.dense(512, activation: :leaky_relu) + |> Axon.dense(512) + |> Axon.relu() |> Axon.batch_norm() - |> Axon.dense(1024, activation: :leaky_relu) + |> Axon.dense(1024) + |> Axon.relu() |> Axon.batch_norm() - |> Axon.dense(784, activation: :tanh) + |> Axon.dense(784) + |> Axon.tanh() + |> Axon.reshape({1, 28, 28}) end - def discriminator do - Axon.input({nil, 28, 28}) + defp build_discriminator(input_shape) do + Axon.input(input_shape) |> Axon.flatten() - |> Axon.dense(512, activation: :tanh) - |> Axon.dense(256, activation: :tanh) + |> Axon.dense(512) + |> Axon.relu() + |> Axon.dense(256) + |> Axon.relu() |> Axon.dense(2, activation: :softmax) end - defn generate(params, latent) do - Axon.predict(generator(), params, latent) + defnp running_average(avg, obs, i) do + avg + |> Nx.multiply(i) + |> Nx.add(obs) + |> Nx.divide(Nx.add(i, 1)) end - defn d_loss(d_params, images, targets) do - preds = Axon.predict(discriminator(), d_params, images) - Axon.Losses.categorical_cross_entropy(preds, targets, reduction: :mean) + defn init(d_model, g_model, init_optim_d, init_optim_g) do + d_params = Axon.init(d_model) + g_params = Axon.init(g_model) + + %{ + iteration: Nx.tensor(0), + discriminator: %{ + model_state: d_params, + optimizer_state: init_optim_d.(d_params), + loss: Nx.tensor(0.0) + }, + generator: %{ + model_state: g_params, + optimizer_state: init_optim_g.(g_params), + loss: Nx.tensor(0.0) + } + } end - defn update_d(params, d_optim_state, images, targets, update_fn) do - gradients = grad(params, &d_loss(&1, images, targets)) - {updates, new_optim_state} = update_fn.(gradients, d_optim_state, params) - {Axon.Updates.apply_updates(params, updates), new_optim_state} - end + defn batch_step(d_model, g_model, optim_d, optim_g, real_images, state) do - defn g_loss(g_params, d_params, latent) do - valid = Nx.iota({32, 2}, axis: 1, type: {:u, 8}) - g_preds = Axon.predict(generator(), g_params, latent) - d_loss(d_params, g_preds, valid) - end + iter = state[:iteration] + d_params = state[:discriminator][:model_state] + g_params = state[:generator][:model_state] - defn update_g(g_params, g_optim_state, d_params, update_fn, latent) do - gradients = grad(g_params, &g_loss(&1, d_params, latent)) + # Update D + fake_labels = Nx.iota({32, 2}, axis: 1) + real_labels = Nx.reverse(fake_labels) + noise = Nx.random_normal({32, 100}) - {updates, new_optim_state} = update_fn.(gradients, g_optim_state, g_params) - 
{Axon.Updates.apply_updates(g_params, updates), new_optim_state} - end + {d_loss, d_grads} = value_and_grad(d_params, fn params -> + fake_images = Axon.predict(g_model, g_params, noise, mode: :train) - def update(g_params, g_optim_state, d_params, d_optim_state, update_fn, images) do - valid = Nx.iota({32, 2}, axis: 1, type: {:u, 8}) - fake = Nx.iota({32, 2}, axis: 1, type: {:u, 8}) |> Nx.reverse(axes: [1]) + d_fake_preds = Axon.predict(d_model, params, fake_images, mode: :train) + d_real_preds = Axon.predict(d_model, params, real_images, mode: :train) - latent = Nx.random_normal({32, 100}) + joint_preds = Nx.concatenate([d_fake_preds, d_real_preds], axis: 0) + joint_labels = Nx.concatenate([fake_labels, real_labels], axis: 0) - fake_images = - g_params - |> generate(latent) - |> Nx.reshape({32, 28, 28}) - - {new_d_params, new_d_state} = - d_params - |> update_d(d_optim_state, images, valid, update_fn) - - {new_d_params, new_d_state} = - new_d_params - |> update_d(new_d_state, fake_images, fake, update_fn) - - {new_g_params, new_g_state} = - g_params - |> update_g(g_optim_state, new_d_params, update_fn, latent) + Axon.Losses.categorical_cross_entropy(joint_labels, joint_preds, reduction: :mean) + end) - {new_g_params, new_g_state, new_d_params, new_d_state} - end + d_optimizer_state = state[:discriminator][:optimizer_state] - def train_epoch(g_params, g_state, d_params, d_state, update_fn, imgs) do - imgs - |> Enum.with_index() - |> Enum.reduce({g_params, g_state, d_params, d_state}, fn - {imgs, i}, {g_params, g_state, d_params, d_state} -> - {new_g, g_state, new_d, d_state} = - update(g_params, g_state, d_params, d_state, update_fn, imgs) + {d_updates, d_optimizer_state} = optim_d.(d_grads, d_optimizer_state, d_params) + d_params = Axon.Updates.apply_updates(d_params, d_updates) - IO.write("\rBatch: #{i}") + # Update G + {g_loss, g_grads} = value_and_grad(g_params, fn params -> + fake_images = Axon.predict(g_model, params, noise, mode: :train) - if rem(i, 50) == 0 do - latent = Nx.random_normal({1, 100}) - IO.inspect Nx.to_heatmap generate(new_g, latent) |> Nx.reshape({1, 28, 28}) - end + d_preds = Axon.predict(d_model, d_params, fake_images) - {new_g, g_state, new_d, d_state} + Axon.Losses.categorical_cross_entropy(real_labels, d_preds, reduction: :mean) end) - end - - def train(imgs, g_params, g_state, d_params, d_state, update_fn, opts \\ []) do - epochs = opts[:epochs] || 5 - - for epoch <- 1..epochs, reduce: {g_params, g_state, d_params, d_state} do - {g_params, g_state, d_params, d_state} -> - {time, {new_g_params, new_g_state, new_d_params, new_d_state}} = - :timer.tc(__MODULE__, :train_epoch, [g_params, g_state, d_params, d_state, update_fn, imgs]) - IO.puts("Epoch #{epoch} Time: #{time / 1_000_000}s") - {new_g_params, new_g_state, new_d_params, new_d_state} - end + g_optimizer_state = state[:generator][:optimizer_state] + + {g_updates, g_optimizer_state} = optim_g.(g_grads, g_optimizer_state, g_params) + g_params = Axon.Updates.apply_updates(g_params, g_updates) + + %{ + iteration: iter + 1, + discriminator: %{ + model_state: d_params, + optimizer_state: d_optimizer_state, + loss: running_average(state[:discriminator][:loss], d_loss, iter) + }, + generator: %{ + model_state: g_params, + optimizer_state: g_optimizer_state, + loss: running_average(state[:generator][:loss], g_loss, iter) + } + } end - defp unzip_cache_or_download(zip) do - base_url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' - path = Path.join("tmp", zip) - - data = - if File.exists?(path) do - 
IO.puts("Using #{zip} from tmp/\n") - File.read!(path) - else - IO.puts("Fetching #{zip} from https://storage.googleapis.com/cvdf-datasets/mnist/\n") - :inets.start() - :ssl.start() - - {:ok, {_status, _response, data}} = :httpc.request(:get, {base_url ++ zip, []}, [], []) - File.mkdir_p!("tmp") - File.write!(path, data) + defp train_loop(d_model, g_model) do + {init_optim_d, optim_d} = Axon.Optimizers.adam(2.0e-3, b1: 0.5) + {init_optim_g, optim_g} = Axon.Optimizers.adam(2.0e-3, b1: 0.5) - data - end + step = &batch_step(d_model, g_model, optim_d, optim_g, &1, &2) + init = fn -> init(d_model, g_model, init_optim_d, init_optim_g) end - :zlib.gunzip(data) + Axon.Loop.loop(step, init) end - def download(images) do - <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = - unzip_cache_or_download(images) + defp log_iteration(state) do + %State{epoch: epoch, iteration: iter, step_state: pstate} = state - train_images = - images - |> Nx.from_binary({:u, 8}) - |> Nx.reshape({n_images, n_rows, n_cols}) - |> Nx.divide(255) - |> Nx.to_batched_list(32) + g_loss = "G: #{:io_lib.format('~.5f', [Nx.to_scalar(pstate[:generator][:loss])])}" + d_loss = "D: #{:io_lib.format('~.5f', [Nx.to_scalar(pstate[:discriminator][:loss])])}" - IO.puts("#{n_images} #{n_rows}x#{n_cols} images\n") + IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, batch: #{Nx.to_scalar(iter)} #{g_loss} #{d_loss}") - train_images + {:continue, state} end -end - -require Axon -generator = MNISTGAN.generator() |> IO.inspect -discriminator = MNISTGAN.discriminator() |> IO.inspect + defp view_generated_images(model, batch_size, state) do + %State{step_state: pstate} = state + noise = Nx.random_normal({batch_size, 100}) + preds = Axon.predict(model, pstate[:generator][:model_state], noise, compiler: EXLA) -train_images = MNISTGAN.download('train-images-idx3-ubyte.gz') + preds + |> Nx.reshape({batch_size, 28, 28}) + |> Nx.to_heatmap() + |> IO.inspect() -IO.puts("Initializing parameters...\n") + {:continue, state} + end -{init_fn, update_fn} = Axon.Optimizers.adam(0.005) + def run() do + {images, _} = Scidata.MNIST.download(transform_images: &transform_images/1) -d_params = Axon.init(discriminator, compiler: EXLA) -d_state = Nx.Defn.jit(init_fn, [d_params], compiler: EXLA) -g_params = Axon.init(generator, compiler: EXLA) -g_state = Nx.Defn.jit(init_fn, [g_params], compiler: EXLA) + generator = build_generator(100) + discriminator = build_discriminator({nil, 1, 28, 28}) -{g_params, _d_params} = MNISTGAN.train(train_images, g_params, g_state, d_params, d_state, update_fn, epochs: 10) + discriminator + |> train_loop(generator) + |> Axon.Loop.handle(:iteration_completed, &log_iteration/1, every: 50) + |> Axon.Loop.handle(:epoch_completed, &view_generated_images(generator, 3, &1)) + |> Axon.Loop.run(images, epochs: 10, compiler: EXLA) + end +end -latent = Nx.random_uniform({1, 100}) -IO.inspect Nx.to_heatmap MNISTGAN.generator(g_params, latent) +MNISTGAN.run() \ No newline at end of file diff --git a/examples/resnet50.exs b/examples/resnet50.exs index 6a3f7fbd..975fec9d 100644 --- a/examples/resnet50.exs +++ b/examples/resnet50.exs @@ -1,5 +1,5 @@ Mix.install([ - {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, + {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"}, ]) defmodule ResNet50 do diff --git a/examples/xor.exs b/examples/xor.exs index c2210649..18838c30 100644 --- a/examples/xor.exs +++ b/examples/xor.exs @@ -1,12 +1,12 @@ -# Normally you wouldn't do this, but this is to demonstrate -# multi input models as just using 
`input` many times Mix.install([ - {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"}, + {:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"}, {:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}, + {:exla, path: "../nx/exla"} ]) defmodule XOR do require Axon + alias Axon.Loop.State defp build_model(input_shape1, input_shape2) do inp1 = Axon.input(input_shape1) @@ -17,33 +17,51 @@ defmodule XOR do |> Axon.dense(1, activation: :sigmoid) end - defp build_data do - for _ <- 1..1000 do - x1 = for _ <- 1..32, do: [Enum.random(0..1)] - x2 = for _ <- 1..32, do: [Enum.random(0..1)] - {Nx.tensor(x1), Nx.tensor(x2)} - end + defp batch do + x1 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)]) + x2 = Nx.tensor(for _ <- 1..32, do: [Enum.random(0..1)]) + y = Nx.logical_xor(x1, x2) + {{x1, x2}, y} end - defp train_model(model, {data, targets}, epochs) do + defp log_metrics( + %State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state, + mode + ) do + loss = + case mode do + :train -> + %{loss: loss} = pstate + "Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}" + + :test -> + "" + end + + metrics = + metrics + |> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end) + |> Enum.join(" ") + + IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}") + + {:continue, state} + end + + defp train_model(model, data, epochs) do model - |> Axon.Training.step(:binary_cross_entropy, Axon.Optimizers.sgd(0.01)) - |> Axon.Training.train(data, targets, epochs: epochs) + |> Axon.Loop.trainer(:binary_cross_entropy, :sgd) + |> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50) + |> Axon.Loop.run(data, epochs: epochs, iterations: 1000) end def run do - model = build_model({:nil, 1}, {:nil, 1}) - - data = build_data() - - targets = - for {x1, x2} <- data do - Nx.logical_xor(x1, x2) - end + model = build_model({nil, 1}, {nil, 1}) + data = Stream.repeatedly(&batch/0) - final_training_state = train_model(model, {data, targets}, 10) + model_state = train_model(model, data, 10) - IO.inspect Axon.predict(model, final_training_state[:params], {Nx.tensor([[0]]), Nx.tensor([[1]])}) + IO.inspect Axon.predict(model, model_state, {Nx.tensor([[0]]), Nx.tensor([[1]])}) end end diff --git a/lib/axon.ex b/lib/axon.ex index 782ddb5f..ce77d2bb 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -72,10 +72,10 @@ defmodule Axon do IO.inspect model - final_params = + model_state = model - |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) - |> Axon.Training.train(train_images, train_labels, epochs: 10, compiler: EXLA) + |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005)) + |> Axon.Loop.run(train_data, epochs: 10, compiler: EXLA) """ alias __MODULE__, as: Axon @@ -1705,8 +1705,8 @@ defmodule Axon do |> Axon.dense(1000, activation: :softmax) model - |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) - |> Axon.Training.train(input, targets, epochs: 10) + |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) + |> Axon.Loop.run(data, epochs: 10) When compiled, frozen parameters are wrapped in `Nx.Defn.Kernel.stop_grad/1`, which zeros out the gradient with respect to the frozen parameter. Gradients @@ -1841,8 +1841,8 @@ defmodule Axon do Compiles the given model to `{init_fn, predict_fn}`. 
""" @doc type: :compilation - def compile(model) do - Axon.Compiler.__compile__(model) + def compile(model, opts \\ []) do + Axon.Compiler.__compile__(model, opts) end @doc """ diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 5db892f9..6d99d4ae 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -20,8 +20,9 @@ defmodule Axon.Compiler do ## Init JIT Compilation @doc false - def __compile__(graph) do - {compile_init(graph), compile_predict(graph, :train)} + def __compile__(graph, opts) do + mode = opts[:mode] || :train + {compile_init(graph), compile_predict(graph, mode)} end @doc false @@ -133,7 +134,7 @@ defmodule Axon.Compiler do @doc false def __jit_predict__(graph, caller, args, opts) do - mode = opts[:mode] || :inference + {mode, opts} = Keyword.pop(opts, :mode, :inference) fun = compile_predict(graph, mode) jit_or_apply(caller, fun, args, opts) end diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex new file mode 100644 index 00000000..2d6a066e --- /dev/null +++ b/lib/axon/loop.ex @@ -0,0 +1,1177 @@ +defmodule Axon.Loop do + @moduledoc """ + Abstraction for modeling a reduction of a dataset with an accumulated + state for a number of epochs. + + Inspired heavily by [PyTorch Ignite](https://pytorch.org/ignite/index.html). + + The main abstraction is the `%Loop{}` struct, which controls a nested + reduction of the form: + + Enum.reduce(1..max_epochs, state, fn epoch, state -> + Enum.reduce(data, state, &batch_step/2) + end) + + `data` is assumed to be an `Enumerable` or `Stream` of input data which is + handled by a processing function, `batch_step`. The purpose of the loop + abstraction is to take away much of the boilerplate used in solving machine + learning tasks. Tasks such as normalizing a dataset, hyperparameter optimization, + or training machine learning models boil down to writing one function: + + defn batch_step(batch, state) do + # ...do something with batch... + updated_state + end + + For tasks such as training a neural network, `state` will encapsulate things + such as model and optimizer state. For supervised learning tasks, `batch_step` + might look something like: + + defn batch_step({inputs, targets}, state) do + %{parameters: params, optimizer_state: optim_state} = state + + gradients = grad(params, objective_fn.(&1, inputs, targets)) + {updates, new_optim_state} = optimizer.(optim_state, params, gradients) + + new_params = apply_updates(params, updates) + + %{parameters: new_params, optimizer_state: optim_state} + end + + `batch_step` takes a batch of `{input, target}` pairs and the current state, + and updates the model parameters based on the gradients received from some arbitrary + objective function. This function will run in a nested loop, iterating over the entire + dataset for `N` epochs before finally returning the trained model state. By defining + 1 function, we've created a training loop that works for most machine learning models. + + In actuality, the loop abstraction accumulates a struct, `Axon.Loop.State`, which looks + like (assuming `container` is a generic Elixir container of tensors, e.g. map, tuple, etc.): + + %State{ + epoch: tensor(), + max_epoch: tensor(), + iteration: tensor(), + max_iteration: tensor(), + metrics: map(string(), container()), + times: list(number()), + step_state: container() + } + + `batch_step` takes in the batch and the step state field and returns a `step_state`, + which is a generic container of state accumulated at each iteration. 
The rest of the fields + in the state struct are updated automatically behind the scenes. + + The loop must start from some initial step state, thus most tasks must also provide + an additional initialization function to provide some starting point for the step + state. For machine learning tasks, the initialization function will return things like + initial model parameters and optimizer state. + + Typically, the final output of the loop is the accumulated final state; however, you + may optionally apply an output transform to extract specific values at the end of the + loop. For example, `Axon.Loop.trainer/4` by default extracts trained model state: + + output_transform = fn state -> + state.step_state[:model_state] + end + + ## Initialize and Step + + The core of the Axon loop are the init and step functions. The initialization is an + arity-0 function which provides an initial step state: + + init = fn -> + %{params: Axon.init(model)} + end + + While the step function is the `batch_step` function mentioned earlier: + + step = fn data, state -> + new_state = # ...do something... + new_state + end + + ## Metrics + + Often times you want to compute metrics assosciated with your training iterations. + To accomplish this, you can attach metrics to each `Axon.Loop`. Assuming a `batch_step` + function which looks like: + + defn batch_step({inputs, targets}, state) do + %{parameters: params, optimizer_state: optim_state} = state + + gradients = grad(params, objective_fn.(&1, inputs, targets)) + {updates, new_optim_state} = optimizer.(optim_state, params, gradients) + + new_params = apply_updates(params, updates) + + # Shown for simplicity, you can optimize this by calculating preds + # along with the gradient calculation + preds = model_fn.(params, inputs) + + %{ + y_true: targets, + y_pred: preds, + parameters: new_params, + optimizer_state: optim_state + } + end + + You can attach metrics to this by using `Axon.Loop.metric/4`: + + Axon.Loop.loop(&batch_step/2) + |> Axon.Loop.metric("Accuracy", :accuracy, fn %{y_true: y_, y_pred: y} -> [y_, y] end) + |> Axon.Loop.run(data) + + Because metrics work directly on `step_state`, you typically need to provide an output + transform to indicate which values should be passed to your metric function. By default, + Axon assumes a supervised training task with the fields `:y_true` and `:y_pred` present + in the step state. See `Axon.Loop.metric/4` for more information. + + Metrics will be tracked in the loop state using the user-provided key. Metrics integrate + seamlessly with the supervised metrics defined in `Axon.Metrics`. You can also use metrics + to keep running averages of some values in the original dataset. + + ## Events and Handlers + + You can instrument several points in the loop using event handlers. By default, several events + are fired when running a loop: + + events = [ + :started, # After loop state initialization + :epoch_started, # On epoch start + :iteration_started, # On iteration start + :iteration_completed, # On iteration complete + :epoch_completed, # On epoch complete + :epoch_halted, # On epoch halt, if early halted + :halted, # On loop halt, if early halted + :completed # On loop completion + ] + + You can attach event handlers to events using `Axon.Loop.handle/4`: + + loop + |> Axon.Loop.handle(:iteration_completed, &log_metrics/1, every: 100) + |> Axon.Loop.run(data) + + The above will trigger `log_metrics/1` every 100 times the `:iteration_completed` event + is fired. 
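+
+  A minimal sketch of such a handler (the `log_metrics/1` name is purely
+  illustrative) only needs to accept the loop state and return a
+  `{:continue, state}` tuple:
+
+      def log_metrics(%Axon.Loop.State{metrics: metrics} = state) do
+        IO.inspect(metrics)
+        {:continue, state}
+      end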
Event handlers must return a tuple `{status, state}`, where `status` is an + atom with one of the following values: + + :continue # Continue epoch, continue looping + :halt_epoch # Halt the epoch, continue looping + :halt_loop # Halt looping + + And `state` is an updated `Axon.Loop.State` struct. Handler functions take as input + the current loop state. + + It's important to note that event handlers are triggered in the order they are attached + to the loop. If you have two handlers on the same event, they will trigger in order: + + loop + |> Axon.Loop.handle(:epoch_completed, &normalize_state/1) # Runs first + |> Axon.Loop.handle(:epoch_completed, &log_state/1) # Runs second + + You may provide filters to filter when event handlers trigger. See `Axon.Loop.handle/4` + for more details on valid filters. + + ## Factories + + Axon loops are typically created from one of the factory functions provided in this + module: + + * `Axon.Loop.loop/3` - Creates a loop from step function and optional initialization + functions and output transform functions. + + * `Axon.Loop.trainer/3` - Creates a supervised training loop from model, loss, and + optimizer. + + * `Axon.Loop.evaluator/2` - Creates a supervised evaluator loop from model and model + state. + + ## Running loops + + In order to execute a loop, you should use `Axon.Loop.run/3`: + + loop + |> Axon.Loop.run(data, epochs: 10) + + ## Resuming loops + + At times you may want to resume a loop from some previous state. You can accomplish this + with `Axon.Loop.from_state/2`: + + loop + |> Axon.Loop.from_state(state) + |> Axon.Loop.run(data) + """ + require Axon + require Axon.Updates + require Logger + + # TODO(seanmor5): Remove when running average is gone + import Nx.Defn + + alias __MODULE__, as: Loop + alias Axon.Loop.State + + @default_events [ + :started, + :epoch_started, + :iteration_started, + :iteration_completed, + :epoch_completed, + :epoch_halted, + :halted, + :completed + ] + + @default_handlers %{ + started: [], + epoch_started: [], + iteration_started: [], + iteration_completed: [], + epoch_completed: [], + epoch_halted: [], + halted: [], + completed: [] + } + + @valid_axon_losses [ + :binary_cross_entropy, + :categorical_cross_entropy, + :categorical_hinge, + :hinge, + :kl_divergence, + :log_cosh, + :mean_absolute_error, + :mean_squared_error, + :poisson, + :soft_margin + ] + + @valid_axon_optimizers [ + :adabelief, + :adagrad, + :adam, + :adamw, + :fromage, + :lamb, + :noisy_sgd, + :radam, + :rmsprop, + :sgd, + :yogi + ] + + @doc false + @derive {Inspect, only: [:metrics, :handlers]} + @enforce_keys [:init, :step] + defstruct [ + :init, + :step, + :attached_state, + :output_transform, + metrics: %{}, + handlers: @default_handlers + ] + + ## Step Factories + + @doc """ + Creates a supervised train step from a model, loss function, and + optimizer. + + This function is intended for more fine-grained control over the loop + creation process. It returns a tuple of `{init_fn, step_fn}` where `init_fn` + is an initialization function which returns an initial step state and + `step_fn` is a supervised train step constructed from `model`, `loss`, + and `optimizer`. + + `model` must be an Axon struct, a valid defn container + of Axon structs, or a `{init_fn, apply_fn}`-tuple where `init_fn` is + an arity-0 function which initializes the model state and `apply_fn` is + an arity-2 function which applies the forward pass of the model. 
+ + `loss` must be an atom which matches a function in `Axon.Losses`, a list + of `{loss, weight}` tuples representing a basic weighted loss function + for multi-output models, or an arity-2 function representing a custom loss + function. + + `optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, + or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which + initializes the optimizer state from attached parameters and `update_fn` is an + arity-3 function which scales gradient updates with respect to input parameters, + optimizer state, and gradients. See `Axon.Updates` for more information on building + optimizers. + """ + def train_step(model, loss, optimizer) do + {init_model_fn, forward_model_fn} = build_model_fns(model, :train) + loss_fn = build_loss_fn(loss) + {init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer) + + init_fn = fn -> + model_state = init_model_fn.() + optimizer_state = init_optimizer_fn.(model_state) + + %{ + i: Nx.tensor(0, backend: Nx.Defn.Expr), + y_true: Nx.tensor(0.0, backend: Nx.Defn.Expr), + y_pred: Nx.tensor(0.0, backend: Nx.Defn.Expr), + loss: Nx.tensor(0.0, backend: Nx.Defn.Expr), + model_state: model_state, + optimizer_state: optimizer_state + } + end + + objective_fn = fn state, inp, tar -> + y_pred = forward_model_fn.(state, inp) + {y_pred, loss_fn.(tar, y_pred)} + end + + step_fn = fn {inp, tar}, state -> + %{i: i, model_state: model_state, optimizer_state: optimizer_state, loss: loss} = state + + {{preds, batch_loss}, gradients} = + Nx.Defn.value_and_grad( + model_state, + &objective_fn.(&1, inp, tar), + fn x -> elem(x, 1) end + ) + + new_loss = running_average(loss, batch_loss, i) + + {updates, new_optimizer_state} = + update_optimizer_fn.(gradients, optimizer_state, model_state) + + %{ + i: Nx.add(i, 1), + y_true: tar, + y_pred: preds, + loss: new_loss, + model_state: Axon.Updates.apply_updates(model_state, updates), + optimizer_state: new_optimizer_state + } + end + + {init_fn, step_fn} + end + + @doc """ + Creates a supervised evaluation step from a model and model state. + + This function is intended for more fine-grained control over the loop + creation process. It returns a tuple of `{init_fn, step_fn}` where + `init_fn` returns an initial step state and `step_fn` performs a + single evaluation step. + """ + def eval_step(model, model_state) do + {_, forward_model_fn} = build_model_fns(model, :inference) + + init_fn = fn -> + %{ + y_true: Nx.tensor(0.0, backend: Nx.Defn.Expr), + y_pred: Nx.tensor(0.0, backend: Nx.Defn.Expr) + } + end + + step_fn = fn {inp, tar}, _ -> + %{ + y_true: tar, + y_pred: forward_model_fn.(model_state, inp) + } + end + + {init_fn, step_fn} + end + + ## Loop Factories + + @doc """ + Creates a loop from `step_fn`, an optional `init_fn`, and an + optional `output_transform`. + + `step_fn` is an arity-2 function which takes a batch and state + and returns an updated step state: + + defn batch_step(batch, step_state) do + step_state + 1 + end + + `init_fn` by default is a function which returns an empty map. You should + define your own if subsequent step state updates rely on an initial + step state: + + defn init_step_state() do + 0 + end + + `step_batch/2` and `init_step_state/0` are typically called from + within `Nx.Defn.jit/3`. While JIT-compilation will work with anonymous functions, + `def`, and `defn`, it is recommended that you use the stricter `defn` to define + both functions in order to avoid bugs or cryptic errors. 
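+
+  As a rough usage sketch, assuming `batch_step/2` and `init_step_state/0` as
+  defined above and some enumerable `data`:
+
+      Axon.Loop.loop(&batch_step/2, &init_step_state/0)
+      |> Axon.Loop.run(data, epochs: 5)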
+ + `output_transform/1` applies a transformation on the final accumulated loop state. + This is useful for extracting specific fields from a loop and piping them into + additional functions. + """ + def loop(step_fn, init_fn \\ fn -> %{} end, output_transform \\ & &1) + when is_function(step_fn, 2) and is_function(init_fn, 0) and + is_function(output_transform, 1) do + %Loop{ + init: init_fn, + step: step_fn, + output_transform: output_transform + } + end + + @doc """ + Creates a supervised training loop from a model, loss function, + and optimizer. + + This function is useful for training models on most standard supervised + learning tasks. It assumes data consists of tuples of input-target pairs, + e.g. `[{x0, y0}, {x1, y1}, ..., {xN, yN}]` where `x0` and `y0` are batched + tensors or containers of batched tensors. + + It defines an initialization function which first initializes model state + using the given model and then initializes optimizer state using the initial + model state. The step function uses a differentiable objective function + defined with respect to the model parameters, input data, and target data + using the given loss function. It then updates model parameters using the + given optimizer in order to minimize loss with respect to the model parameters. + + `model` must be an Axon struct, a valid defn container + of Axon structs, or a `{init_fn, apply_fn}`-tuple where `init_fn` is + an arity-0 function which initializes the model state and `apply_fn` is + an arity-2 function which applies the forward pass of the model. + + `loss` must be an atom which matches a function in `Axon.Losses`, a list + of `{loss, weight}` tuples representing a basic weighted loss function + for multi-output models, or an arity-2 function representing a custom loss + function. + + `optimizer` must be an atom matching the name of a valid optimizer in `Axon.Optimizers`, + or a `{init_fn, update_fn}` tuple where `init_fn` is an arity-1 function which + initializes the optimizer state from attached parameters and `update_fn` is an + arity-3 function which scales gradient updates with respect to input parameters, + optimizer state, and gradients. See `Axon.Updates` for more information on building + optimizers. 
+ + This function creates a step function which outputs a map consisting of the following + fields for `step_state`: + + %{ + y_pred: tensor() | container(tensor()), # Model predictions for use in metrics + y_true: tensor() | container(tensor()), # True labels for use in metrics + loss: tensor(), # Running average of loss over epoch + model_state: container(tensor()), # Model parameters and state + optimizer_state: container(tensor()) # Optimizer state assosciated with each parameter + } + + ## Examples + + ### Basic usage + + data = Stream.zip(input, target) + + model = Axon.input({nil, 32}) |> Axon.dense(1, activation: :sigmoid) + + model + |> Axon.Loop.trainer(:binary_cross_entropy, :adam) + |> Axon.Loop.run(data) + + ### Customizing Optimizer + + model + |> Axon.Loop.trainer(:binary_cross_entropy, Axon.Optimizers.adam(0.05)) + |> Axon.Loop.run(data) + + ### Custom loss + + loss_fn = fn y_true, y_pred -> Nx.cos(y_true, y_pred) end + + model + |> Axon.Loop.trainer(loss_fn, Axon.Optimizers.rmsprop(0.01)) + |> Axon.Loop.run(data) + + ### Multiple objectives with multi-output model + + model = {Axon.input({nil, 1}), Axon.input({nil, 2})} + loss_weights = [mean_squared_error: 0.5, mean_absolute_error: 0.5] + + model + |> Axon.Loop.trainer(loss_weights) + |> Axon.Loop.run(data) + """ + def trainer(model, loss, optimizer) do + {init_fn, step_fn} = train_step(model, loss, optimizer) + output_transform = fn state -> state.step_state[:model_state] end + loop(step_fn, init_fn, output_transform) + end + + @doc """ + Creates a supervised evaluator from a model and model state. + + An evaluator can be used for things such as testing and validation of models + after or during training. It assumes `model` is an Axon struct, container of + structs, or a tuple of `init` / `apply` functions. `model_state` must be a + container useable from within `model`. + + The evaluator returns a step state of the form: + + %{ + y_true: labels, + y_pred: predictions + } + + Such that you can attach any number of supervised metrics to the evaluation + loop: + + model + |> Axon.Loop.evaluator(trained_state) + |> Axon.Loop.metric("Accuracy", :accuracy) + + Applies an output transform which returns the map of metrics accumulated over + the given loop. + """ + def evaluator(model, model_state) do + {init_fn, step_fn} = eval_step(model, model_state) + output_transform = fn state -> state.metrics end + loop(step_fn, init_fn, output_transform) + end + + @doc """ + Adds a metric of the given name to the loop. + + A metric is a function which tracks or measures some value with respect + to values in the step state. For example, when training classification + models, it's common to track the model's accuracy during training: + + loop + |> Axon.Loop.metric(:accuracy, "Accuracy") + + By default, metrics assume a supervised learning task and extract the fields + `[:y_true, :y_pred]` from the step state. If you wish to work on a different + value, you can use an output transform. An output transform is a list of keys + to extract from the output state, or a function which returns a flattened list + of values to pass to the given metric function. Values received from output + transforms are passed to the given metric using: + + value = output_transform.(step_state) + apply(metric, value) + + Thus, even if you want your metric to work on a container, your output transform + must return a list. 
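+
+  As an illustrative sketch, a metric computed from custom step state fields
+  (the `:target` and `:reconstruction` keys here are hypothetical) can be
+  attached with an explicit transform that returns a list:
+
+      loop
+      |> Axon.Loop.metric(:mean_absolute_error, "Error", fn pstate ->
+        [pstate[:target], pstate[:reconstruction]]
+      end)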
+ + `metric` must be an atom which matches the name of a metric in `Axon.Metrics`, or + an arbitrary function which returns a tensor or container. + + `name` must be a string or atom used to store the computed metric in the loop + state. If names conflict, the last attached metric will take precedence: + + loop + |> Axon.Loop.metric(:mean_squared_error, "Error") # Will be overwritten + |> Axon.Loop.metric(:mean_absolute_error, "Error") # Will be used + """ + def metric( + %Loop{metrics: metric_fns} = loop, + metric, + name, + transform_or_fields \\ [:y_true, :y_pred] + ) do + case metric_fns do + %{^name => _} -> + Logger.warning( + "Metric #{name} declared twice in loop. Original metric will be overriden." + ) + + _ -> + :ok + end + + metric_fn = build_metric_fn(metric, transform_or_fields) + %Loop{loop | metrics: Map.put(metric_fns, name, metric_fn)} + end + + @doc """ + Adds a handler function to the loop which will be triggered on `event` + with an optional filter. + + Events take place at different points during loop execution. The default + events are: + + events = [ + :started, # After loop state initialization + :epoch_started, # On epoch start + :iteration_started, # On iteration start + :iteration_completed, # On iteration complete + :epoch_completed, # On epoch complete + :epoch_halted, # On epoch halt, if early halted + :halted, # On loop halt, if early halted + :completed # On loop completion + ] + + Generally, event handlers are side-effecting operations which provide some + sort of inspection into the loop's progress. It's important to note that + if you define multiple handlers to be triggered on the same event, they + will execute in order from when they were attached to the training + loop: + + loop + |> Axon.Loop.handle(:epoch_started, &normalize_step_state/1) # executes first + |> Axon.Loop.handle(:epoch_started, &log_step_state/1) # executes second + + Thus, if you have separate handlers which alter or depend on loop state, + you need to ensure they are ordered correctly, or combined into a single + event handler for maximum control over execution. + + `event` must be an atom representing the event to trigger `handler` or a + list of atoms indicating `handler` should be triggered on multiple events. + `event` may be `:all` which indicates the handler should be triggered on + every event during loop processing. + + `handler` must be an arity-1 function which takes as input loop state and + returns `{status, state}`, where `status` is an atom with one of the following + values: + + :continue # Continue epoch, continue looping + :halt_epoch # Halt the epoch, continue looping + :halt_loop # Halt looping + + `filter` is an atom representing a valid filter predicate, a keyword of + predicate-value pairs, or a function which takes loop state and returns + a `true`, indicating the handler should run, or `false`, indicating the + handler should not run. 
Valid predicates are: + + :always # Always trigger event + :once # Trigger on first event firing + + Valid predicate-value pairs are: + + every: N # Trigger every `N` event + only: N # Trigger on `N` event + """ + # TODO(seanmor5): Custom events + def handle(%Loop{handlers: handle_fns} = loop, event, handler, filter \\ :always) do + filter = build_filter_fn(filter) + + handle_fns = + case event do + [_ | _] = events -> + Enum.reduce(events, handle_fns, &add_event_handler(&1, &2, {handler, filter})) + + :all -> + Enum.reduce(@default_events, handle_fns, &add_event_handler(&1, &2, {handler, filter})) + + event when is_atom(event) -> + add_event_handler(event, handle_fns, {handler, filter}) + end + + %Loop{loop | handlers: handle_fns} + end + + @doc """ + Attaches `state` to the given loop in order to resume looping + from a previous state. + + It's important to note that a loop's attached state takes precedence + over defined initialization functions. Given initialization function: + + defn init_state(), do: %{foo: 1, bar: 2} + + And an attached state: + + state = %State{step_state: %{foo: 2, bar: 3}} + + `init_state/0` will never execute, and instead the initial step state + of `%{foo: 2, bar: 3}` will be used. + """ + def from_state(%Loop{} = loop, %State{} = state) do + %{loop | attached_state: state} + end + + @doc """ + Runs the given loop on data with the given options. + + `loop` must be a valid Axon.Loop struct built from one of the + loop factories provided in this module. + + `data` must be an Enumerable or Stream which yields batches of + data on each iteration. + + ## Options + + * `:epochs` - max epochs to run loop for. Must be non-negative integer. + Defaults to `1`. + + * `:iterations` - max iterations to run each epoch. Must be non-negative + integer. Defaults to `nil` or no max iterations. + + * `:jit_compile?` - whether or not to JIT compile initialization and step + functions. JIT compilation must be used for gradient computations. Defaults + to true. + + * `:compiler` - Nx compiler to use to JIT compile step function. Defaults + to `nil` or Nx.Defn.Evaluator. 
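+
+  ## Examples
+
+      # A sketch of a typical call, assuming `loop` and `data` are built as in
+      # the factory examples above. The EXLA compiler is only usable when the
+      # :exla dependency is available.
+      Axon.Loop.run(loop, data, epochs: 10, iterations: 1000, compiler: EXLA)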
+ """ + def run(loop, data, opts \\ []) do + {max_epochs, opts} = Keyword.pop(opts, :epochs, 1) + {max_iterations, opts} = Keyword.pop(opts, :iterations, -1) + {jit_compile?, opts} = Keyword.pop(opts, :jit_compile?, true) + {compiler, jit_opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) + + %Loop{ + init: init_fn, + step: step_fn, + handlers: handler_fns, + metrics: metric_fns, + attached_state: attached_state, + output_transform: output_transform + } = loop + + loop_state = + init_loop_state( + init_fn, + attached_state, + metric_fns, + max_epochs, + max_iterations, + jit_compile?, + compiler, + jit_opts + ) + + {status, state} = + case fire_event(:started, handler_fns, loop_state) do + {:halt_epoch, state} -> + {:halted, state} + + {:halt_loop, state} -> + {:halted, state} + + {:continue, state} -> + Enum.reduce_while(0..(max_epochs - 1)//1, {:completed, state}, fn epoch, + {_, loop_state} -> + case fire_event(:epoch_started, handler_fns, loop_state) do + {:halt_epoch, state} -> + halt_epoch(handler_fns, state) + + {:halt_loop, state} -> + {:halt, {:halted, state}} + + {:continue, state} -> + {time, status_and_state} = + :timer.tc(&run_epoch/8, [ + step_fn, + metric_fns, + handler_fns, + state, + data, + jit_compile?, + compiler, + jit_opts + ]) + + case status_and_state do + {:halt_epoch, state} -> + halt_epoch(handler_fns, state) + + {:halt_loop, state} -> + {:halt, {:halted, state}} + + {:continue, state} -> + new_times = Map.put(state.times, Nx.to_scalar(epoch), time) + new_loop_state = %State{state | times: new_times} + + case fire_event(:epoch_completed, handler_fns, new_loop_state) do + {:halt_epoch, state} -> + halt_epoch(handler_fns, state) + + {:halt_loop, state} -> + {:halt, {:halted, state}} + + {:continue, state} -> + max_iter = state.iteration + + {:cont, + {:completed, + %State{state | epoch: epoch + 1, iteration: 0, max_iteration: max_iter}}} + end + end + end + end) + end + + {_, state} = fire_event(status, handler_fns, state) + + output_transform.(state) + end + + ## Helpers + + defp init_loop_state( + init_fn, + attached_state, + metric_fns, + max_epochs, + max_iterations, + jit_compile?, + compiler, + jit_opts + ) do + case attached_state do + %State{} = state -> + state + + nil -> + metrics = Map.new(metric_fns, fn {k, _} -> {k, Nx.tensor(0.0)} end) + step_state = maybe_jit(init_fn, [], jit_compile?, compiler, jit_opts) + + %State{ + epoch: 0, + max_epoch: max_epochs, + iteration: 0, + max_iteration: max_iterations, + step_state: step_state, + metrics: metrics, + times: %{} + } + end + end + + defp run_epoch( + step_fn, + metric_fns, + handler_fns, + loop_state, + data, + jit_compile?, + compiler, + jit_opts + ) do + Enum.reduce_while(data, {:continue, loop_state}, fn data, {_, state} -> + case fire_event(:iteration_started, handler_fns, state) do + {:halt_epoch, state} -> + {:halt, {:halt_epoch, state}} + + {:halt_loop, state} -> + {:halt, {:halt_loop, state}} + + {:continue, state} -> + batch_fn = build_batch_fn(step_fn, metric_fns) + + %State{iteration: iters, max_iteration: max_iters} = + new_state = maybe_jit(batch_fn, [data, state], jit_compile?, compiler, jit_opts) + + case fire_event(:iteration_completed, handler_fns, new_state) do + {:halt_epoch, state} -> + {:halt, {:halt_epoch, state}} + + {:halt_loop, state} -> + {:halt, {:halt_loop, state}} + + {:continue, state} -> + iters = Nx.to_scalar(iters) + max_iters = Nx.to_scalar(max_iters) + + if iters > max_iters and max_iters != -1 do + {:halt, {:continue, state}} + else + {:cont, {:continue, state}} + 
end + end + end + end) + end + + # Adds an event handler to the map of handler funs by prepending handler + # to the existing handler funs. Because we prepend here, we must reverse + # handler funs in fire_event. + # TODO(seanmor5): Custom events + defp add_event_handler(event, handle_fns, handler) do + Map.update!(handle_fns, event, fn event_funs -> [handler | event_funs] end) + end + + # Fires event `event` using handler_fns assosciated with the event. We + # must reverse handler funs in order to enforce order that handlers are + # attached to the loop. + # TODO(seanmor5): Custom events + defp fire_event(event, handler_fns, state) do + handler_fns[event] + |> Enum.reverse() + |> Enum.reduce_while({:continue, state}, fn {handler, filter}, {_, state} -> + if filter.(state) do + case handler.(state) do + {:continue, %State{} = state} -> + {:cont, {:continue, state}} + + {:halt_epoch, %State{} = state} -> + {:halt, {:halt_epoch, state}} + + {:halt_loop, %State{} = state} -> + {:halt, {:halt_loop, state}} + + invalid -> + raise ArgumentError, + "invalid value #{inspect(invalid)} returned from event handler" <> + " triggered on #{inspect(event)}, event handler must return" <> + " a tuple of {status, state} where status is one of :halt_epoch," <> + " :halt_loop, or :continue and state is an updated State struct" + end + else + {:cont, {:continue, state}} + end + end) + end + + # Halts an epoch during looping + defp halt_epoch(handler_fns, loop_state) do + case fire_event(:epoch_halted, handler_fns, loop_state) do + {:halt_epoch, state} -> + {:cont, %State{state | epoch: state.epoch + 1, iteration: 0}} + + {:halt_loop, state} -> + {:halt, {:halted, state}} + + {:continue, state} -> + {:cont, state} + end + end + + # Builds the overall batch step function from the given + # step function and metrics. We need to run both step and metric + # functions from within here to ensure they can be JIT compiled + # if that's desired + defp build_batch_fn(step_fn, metric_fns) do + fn data, state -> + %State{metrics: metrics, iteration: iter, step_state: pstate} = state + new_step_state = step_fn.(data, pstate) + + new_metrics = + metrics + |> Enum.zip_with(metric_fns, fn {k, avg}, {k, v} -> + {k, running_average(avg, v.(new_step_state), iter)} + end) + |> Map.new() + + %State{ + state + | iteration: Nx.add(iter, 1), + step_state: new_step_state, + metrics: new_metrics + } + end + end + + # Builds a loss function from an atom, function, or list of. Valid loss + # functions must be one of an atom matching the name of a function in + # Axon.Losses, an arity-2 function of the form loss(y_true, y_pred), + # or a list of 2-tuples of {loss, weight} for constructing a simple + # joint, multi-objective loss function. + # TODO(seanmor5): Configurable per-batch reductions + # TODO(seanmor5): Configurable multi-objective reductions + # TODO(seanmor5): Should we trace custom loss functions and provide a + # more clear error if the output shape is wrong? 
+ defp build_loss_fn(loss) do + case loss do + loss_name when is_atom(loss_name) and loss_name in @valid_axon_losses -> + &apply(Axon.Losses, loss_name, [&1, &2, [reduction: :mean]]) + + loss_fn when is_function(loss, 2) -> + loss_fn + + [{_, _} | _] = losses -> + fn y_true, y_pred -> + {_, loss} = + Enum.reduce(losses, {0, Nx.tensor(0, backend: Nx.Defn.Expr)}, fn {loss, weight}, + {i, acc_loss} -> + loss_fn = build_loss_fn(loss) + + y_true_i = elem(y_true, i) + y_pred_i = elem(y_pred, i) + + new_acc_loss = + y_true_i + |> loss_fn.(y_pred_i) + |> Nx.multiply(weight) + |> Nx.add(acc_loss) + + {i + 1, new_acc_loss} + end) + + loss + end + + invalid -> + raise ArgumentError, + "Invalid loss function #{inspect(invalid)}, a valid loss" <> + " function is an atom which matches a function in Axon.Losses," <> + " an arity-2 function of the form loss(y_true, y_pred), or a list" <> + " of 2-tuples of {loss, weight} for multi-objective models" + end + end + + # Builds model init and forward functions from an Axon struct, + # a tuple of Axon structs, or a tuple of init / forward + # functions. Model functions are essentially just model + # init / apply functions. + # TODO(seanmor5): Update this to support any valid defn container + defp build_model_fns(%Axon{} = model, mode) do + Axon.compile(model, mode: mode) + end + + defp build_model_fns({init_fn, forward_fn}, _) + when is_function(init_fn, 0) and is_function(forward_fn, 2) do + {init_fn, forward_fn} + end + + defp build_model_fns(model, mode) when is_tuple(model) do + Axon.compile(model, mode: mode) + end + + defp build_model_fns(invalid, _) do + raise ArgumentError, + "Invalid model #{inspect(invalid)}, a valid model" <> + " is an Axon struct, a container of Axon structs " <> + " or a tuple of {init_fn, forward_fn} with signatures" <> + " init_fn() :: model_state, forward_fn(model_state, inp) :: prediction" + end + + # Builds optimizer init and update functions either from an atom + # or a tuple of init / update functions. The init and update functions + # match the signatures of those defined in Axon.Updates. If the + # optimizer is an atom, it must match the name of a function in + # Axon.Optimizers. + defp build_optimizer_fns(optimizer) + when is_atom(optimizer) and optimizer in @valid_axon_optimizers do + # TODO(seanmor5): Fall back to optimizer defaults rather + # than this global default. + apply(Axon.Optimizers, optimizer, [1.0e-2]) + end + + defp build_optimizer_fns({init_optimizer_fn, update_optimizer_fn}) + when is_function(init_optimizer_fn, 1) and is_function(update_optimizer_fn, 3) do + {init_optimizer_fn, update_optimizer_fn} + end + + defp build_optimizer_fns(invalid) do + raise ArgumentError, + "Invalid optimizer #{inspect(invalid)}, a valid optimizer" <> + " is an atom matching the name of an optimizer in Axon.Optimizers" <> + " or a tuple of {init_fn, update_fn}. See Axon.Updates for more" <> + " information on building optimizers using the low-level API" + end + + # Builds a metric function from an atom or function and an output transform. + # A valid metric is an atom which matches the name of a function in + # Axon.Metrics or a function which takes an arbitrary number of parameters + # and returns an output of arbitrary shape/type. Output transforms are field(s) + # to extract from the step state, or a function which transforms the step + # state before it is passed to the metric function. 
+ # TODO(seanmor5): Reconsider the form of output transform
+ defp build_metric_fn(metric, transform_or_fields) do
+ transform_fn =
+ case transform_or_fields do
+ [_ | _] = fields ->
+ fn output ->
+ fields
+ |> Enum.reduce([], fn field, acc -> [output[field] | acc] end)
+ |> Enum.reverse()
+ end
+
+ field when is_atom(field) ->
+ fn output ->
+ output[field]
+ end
+
+ transform when is_function(transform, 1) ->
+ transform
+
+ invalid ->
+ raise ArgumentError,
+ "Invalid output transform #{inspect(invalid)}, a valid output" <>
+ " transform is an atom or list of atoms specifying field(s)" <>
+ " to extract from the step state, or an arity-1 function" <>
+ " applied to the step state"
+ end
+
+ case metric do
+ metric when is_atom(metric) ->
+ fn output ->
+ output
+ |> transform_fn.()
+ |> then(&apply(Axon.Metrics, metric, &1))
+ end
+
+ metric_fn when is_function(metric) ->
+ fn output ->
+ output
+ |> transform_fn.()
+ |> then(&apply(metric_fn, &1))
+ end
+
+ invalid ->
+ raise ArgumentError,
+ "Invalid metric #{inspect(invalid)}, a valid metric" <>
+ " is an atom which matches the name of a function in" <>
+ " Axon.Metrics or a function which takes a transformed" <>
+ " step state and returns a value"
+ end
+ end
+
+ # Builds a filter function from an atom, keyword list, or function. A
+ # valid filter is an atom which matches one of the valid predicates `:always`
+ # or `:once`, a keyword which matches one of the valid predicate-value pairs
+ # such as `every: N`, or a function which takes loop state and returns `true`
+ # or `false`.
+ #
+ # TODO(seanmor5): In order to handle custom events and predicate filters,
+ # we will need to track event firings in the loop state.
+ defp build_filter_fn(filter) do
+ case filter do
+ :always ->
+ fn _ -> true end
+
+ :once ->
+ fn
+ %State{epoch: 0, iteration: 0} -> true
+ _ -> false
+ end
+
+ [{:every, n} | _] ->
+ fn %State{iteration: iter} ->
+ Nx.remainder(iter, n) == Nx.tensor(0)
+ end
+
+ fun when is_function(fun, 1) ->
+ fun
+
+ invalid ->
+ raise ArgumentError,
+ "Invalid filter #{inspect(invalid)}, a valid filter" <>
+ " is an atom which matches a valid filter predicate" <>
+ " such as :always or :once, a keyword of predicate-value" <>
+ " pairs such as every: N, or an arity-1 function which takes" <>
+ " loop state and returns true or false"
+ end
+ end
+
+ # JIT-compiles the given function with the given compiler when jit_compile?
+ # is true, otherwise applies the function with the given arguments.
+ defp maybe_jit(fun, args, jit_compile?, compiler, jit_opts) do
+ if jit_compile? do
+ Nx.Defn.jit(fun, args, [compiler: compiler] ++ jit_opts)
+ else
+ apply(fun, args)
+ end
+ end
+
+ # TODO(seanmor5): Move to metrics as a combinator
+ defnp running_average(avg, value, i) do
+ avg
+ |> Nx.multiply(i)
+ |> Nx.add(value)
+ |> Nx.divide(Nx.add(i, 1))
+ end
+end
diff --git a/lib/axon/loop/process.ex b/lib/axon/loop/process.ex
new file mode 100644
index 00000000..c83240a4
--- /dev/null
+++ b/lib/axon/loop/process.ex
@@ -0,0 +1,18 @@
+defmodule Axon.Loop.Process do
+ @moduledoc false
+
+ # Process function which runs iteratively within a loop,
+ # reducing over data and accumulating process state. The process
+ # state is initialized from `:init` and updated at each iteration
+ # with `:update`.
+
+ # Fields
+ #
+ # :init - Initialization of process state, loops are modeled
+ # as a reduction over some data with the process state as the
+ # accumulator.
This will initialize the state of the accumulator + # + # :update - Process function or update function. Performs processing + # and updates of the process state. + defstruct [:init, :update] +end diff --git a/lib/axon/loop/state.ex b/lib/axon/loop/state.ex new file mode 100644 index 00000000..1f51be20 --- /dev/null +++ b/lib/axon/loop/state.ex @@ -0,0 +1,51 @@ +defmodule Axon.Loop.State do + @moduledoc """ + Accumulated state in an Axon.Loop. + + Loop state is a struct: + + %State{ + epoch: tensor(), + max_epoch: tensor(), + iteration: tensor(), + max_iteration: tensor(), + metrics: map(string(), container()), + times: list(number()), + step_state: container() + } + + `epoch` is the current epoch, starting at 0, of the nested loop. + Defaults to 0. + + `max_epoch` is the maximum number of epochs the loop should run + for. Defaults to 1. + + `iteration` is the current iteration of the inner loop. In supervised + settings, this will be the current batch. Defaults to 0. + + `max_iteration` is the maximum number of iterations the loop should + run a given epoch for. Defaults to -1 (no max). + + `metrics` is a map of `%{"metric_name" => value}` which accumulates metrics + over the course of loop processing. Defaults to an empty map. + + `times` is a map of `%{epoch_number => value}` which maps a given epoch + to the processing time. Defaults to an empty map. + + `step_state` is the step state as defined by the loop's processing + initialization and update functions. `step_state` is a required field. + """ + # TODO(seanmor5): We should not send `:times` to the device. We need + # a way in Nx/EXLA to mark `:times` as a static property which is + # not to be touched at JIT time. + @enforce_keys [:step_state] + defstruct [ + :step_state, + epoch: 0, + max_epoch: 1, + iteration: 0, + max_iteration: -1, + metrics: %{}, + times: %{} + ] +end diff --git a/lib/axon/optimizers.ex b/lib/axon/optimizers.ex index 61a6b168..7e3e8bec 100644 --- a/lib/axon/optimizers.ex +++ b/lib/axon/optimizers.ex @@ -48,8 +48,8 @@ defmodule Axon.Optimizers do For a simpler approach, you can also use optimizers with the training API: model - |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) - |> Axon.Training.train(train_images, train_labels, epochs: 10, compiler: EXLA) + |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adam(0.005)) + |> Axon.Loop.run(data, epochs: 10, compiler: EXLA) """ alias Axon.Updates diff --git a/lib/axon/training.ex b/lib/axon/training.ex deleted file mode 100644 index 4b714cc8..00000000 --- a/lib/axon/training.ex +++ /dev/null @@ -1,427 +0,0 @@ -defmodule Axon.Training do - @moduledoc """ - Abstractions for training machine learning models. - """ - require Axon - require Axon.Updates - - import Axon.Shared - - alias Axon.Training.Step - - @doc false - def step({_, _} = model, {_, _} = update), do: step(model, update, []) - - @doc """ - Represents a single training step. - - The first two arguments are tuples: - - * The first tuple contains the model initialization function - and the objective function. For a Neural Network, the objective - function is the loss function of the Neural Network prediction - - * The second pairs contains the updater initialization function - and the update function itself - - ## Options - - * `:metrics` - metrics to track during each training step. Can be an - atom representing a function in `Axon.Metrics`, or a 2-arity function - taking `y_true` and `y_pred` as args. 
- - """ - def step({init_model_fn, objective_fn}, {init_update_fn, update_fn}, opts) - when is_function(init_model_fn, 0) and is_function(objective_fn, 3) and - is_function(init_update_fn, 1) and is_function(update_fn, 3) and is_list(opts) do - metrics = opts[:metrics] || [] - - update_metrics_fn = fn old_metrics, step, y_true, y_pred -> - Map.new(metrics, fn - {key, fun} -> - batch_metric = fun.(y_true, y_pred) - - avg_metric = - old_metrics[key] - |> Nx.multiply(step) - |> Nx.add(batch_metric) - |> Nx.divide(Nx.add(step, 1)) - - {key, avg_metric} - - key -> - batch_metric = apply(Axon.Metrics, key, [y_true, y_pred]) - - avg_metric = - old_metrics[key] - |> Nx.multiply(step) - |> Nx.add(batch_metric) - |> Nx.divide(Nx.add(step, 1)) - - {key, avg_metric} - end) - end - - init_fn = fn -> - params = init_model_fn.() - optim_params = init_update_fn.(params) - - init_metrics = Map.new(metrics, fn k -> {k, Nx.tensor(0.0, backend: Nx.Defn.Expr)} end) - - %{ - epoch: Nx.tensor(0, backend: Nx.Defn.Expr), - epoch_step: Nx.tensor(0, backend: Nx.Defn.Expr), - epoch_loss: Nx.tensor(0.0, backend: Nx.Defn.Expr), - params: params, - optimizer_state: optim_params, - metrics: init_metrics - } - end - - step_fn = fn train_state, input, target -> - {{preds, batch_loss}, gradients} = - Nx.Defn.value_and_grad( - train_state[:params], - &objective_fn.(&1, input, target), - fn x -> elem(x, 1) end - ) - - new_metrics = - case metrics do - [] -> - %{} - - _ -> - update_metrics_fn.(train_state[:metrics], train_state[:epoch_step], target, preds) - end - - epoch_avg_loss = - train_state[:epoch_loss] - |> Nx.multiply(train_state[:epoch_step]) - |> Nx.add(batch_loss) - |> Nx.divide(Nx.add(train_state[:epoch_step], 1)) - - {updates, new_update_state} = - update_fn.(gradients, train_state[:optimizer_state], train_state[:params]) - - updates = - deep_merge(updates, train_state[:params], fn g, x -> Nx.as_type(g, Nx.type(x)) end) - - %{ - epoch: train_state[:epoch], - epoch_step: Nx.add(train_state[:epoch_step], 1), - epoch_loss: epoch_avg_loss, - params: Axon.Updates.apply_updates(train_state[:params], updates), - optimizer_state: new_update_state, - metrics: new_metrics - } - end - - %Step{init: init_fn, step: step_fn, callbacks: []} - end - - @doc false - def step(%Axon{} = model, loss, {_, _} = optimizer) when is_function(loss, 2) or is_atom(loss), - do: step(model, loss, optimizer, []) - - @doc """ - Represents a single training step using an Axon `model`, - `loss` function, and `optimizer`. - - The `loss` function is either an atom or a two arity - anonymous function. - """ - def step(%Axon{} = model, loss, optimizer, opts) - when is_function(loss, 2) and is_list(opts) do - {init_fn, predict_fn} = Axon.compile(model) - - objective_fn = fn params, input, target -> - preds = predict_fn.(params, input) - loss = Nx.add(loss.(target, preds), Axon.penalty(model, params)) - {preds, loss} - end - - step({init_fn, objective_fn}, optimizer, opts) - end - - def step(%Axon{} = model, loss, optimizer, opts) when is_atom(loss) and is_list(opts) do - loss_fn = &apply(Axon.Losses, loss, [&1, &2, [reduction: :mean]]) - step(model, loss_fn, optimizer, opts) - end - - @doc false - def step(%Axon{} = model, train_state, loss, {_, _} = optimizer) - when is_function(loss, 2) or is_atom(loss), - do: step(model, train_state, loss, optimizer, []) - - @doc """ - Represents a single training step using an Axon `model`, - initial state `train_state`, `loss` function and `optimizer`. 
- - The `loss` function is either an atom or a two arity anonymous - function. - """ - def step(%Axon{} = model, train_state, loss, optimizer, opts) - when is_function(loss, 2) and is_list(opts) do - init_fn = fn -> - train_state - |> Tuple.to_list() - |> Enum.map(&Nx.tensor(&1, backend: Nx.Defn.Expr)) - |> List.to_tuple() - end - - objective_fn = fn params, input, target -> - preds = Axon.predict(model, params, input, mode: :train) - Nx.add(loss.(target, preds), Axon.penalty(model, params)) - end - - step({init_fn, objective_fn}, optimizer, opts) - end - - def step(%Axon{} = model, train_state, loss, optimizer, opts) - when is_atom(loss) and is_list(opts) do - loss_fn = &apply(Axon.Losses, loss, [&1, &2, [reduction: :mean]]) - step(model, train_state, loss_fn, optimizer, opts) - end - - @valid_callbacks [:early_stopping] - - @doc false - def callback(%Step{} = step, callback) when callback in @valid_callbacks do - callback(step, callback, []) - end - - @doc """ - Adds a callback from `Axon.Training.Callbacks` to the training step. - """ - def callback(%Step{} = step, callback, opts) when callback in @valid_callbacks do - fun = &apply(Axon.Training.Callbacks, callback, [&1, &2, &3]) - callback(step, fun, :all, opts) - end - - @doc """ - Adds a callback function to the training step. - - Callback functions instrument specific points in the training loop. - You can specify an `event` which is one of: - - - `:before_{train, epoch, batch}` - - `:after_{train, epoch, batch}` - - The default `event` is `:all`, meaning the callback will run at every - callback point. - - Callback functions have the following signature: - - callback_fn(train_state :: map, event :: atom, opts :: keyword) :: - {:cont, train_state} | {:halt, train_state} - - You can trigger event-specific behavior using pattern matching: - - def my_callback(train_state, :before_epoch, _opts) do - {:cont, %{train_state | my_metadata: 0}} - end - - def my_callback(train_state, :after_epoch, _opts) do - {:cont, %{train_state | my_metadata: train_state[:metadata] + 1}} - end - - def my_callback(train_state, _event, _opts), do: {:cont, train_state} - - Returning `{:halt, train_state}` will immediately terminate the training loop: - - def early_stopping(train_state, :after_epoch, opts) do - if stop?(train_state, opts) do - {:halt, train_state} - else - {:cont, train_state} - end - end - """ - def callback(%Step{callbacks: callbacks} = step, function, event \\ :all, opts \\ []) - when is_function(function, 3) and is_atom(event) and is_list(opts) do - %{step | callbacks: [{function, event, opts} | callbacks]} - end - - @doc """ - Implements a common training loop. - - Its arguments are: - - * A tuple with the initialization function and the step function. - Often retrieved from `step/3` but it could also be manually provided. - - * The inputs tensors - - * The targets tensors - - * A list of options - - ## Options - - * `:epochs` - number of epochs to train for. Defaults to `5`. - * `:compiler` - `defn` compiler to use to run training loop. - Defaults to `Nx.Defn.Evaluator`. - * `:log_every` - frequency with which to log training loss. - Accepts an integer referring to number of batches, `:epoch`, - or `:none`. Defaults to `:epoch`. - - All other options are given to the underlying compiler. - - ## A note on Nx and anonymous functions - - When training, both `init_fn` and `step_fn` are executed within - the given Nx `:compiler`. Therefore, it is required that `init_fn` - and `step_fn` work on tensor expressions instead of tensor values. 
- - For example, let's suppose you want to initialize the values with: - - Nx.random_uniform({40, 28}, 0, 1) - - The following won't work: - - params = Nx.random_uniform({40, 28}, 0, 1) - init_fn = fn -> params end - - Instead, we want to build the values inside the given compiler. - The correct way to build those values is by computing them inside - a defn: - - defn init_values, do: Nx.random_uniform({40, 28}, 0, 1) - - And then: - - init_fn = &init_values/0 - - """ - def train( - %Step{init: init_fn, step: step_fn, callbacks: callbacks}, - inputs, - targets, - opts \\ [] - ) do - epochs = opts[:epochs] || 5 - compiler = opts[:compiler] || Nx.Defn.Evaluator - log_every = opts[:log_every] || 50 - - callbacks = [ - {&Axon.Training.Callbacks.standard_io_logger(&1, &2, &3), :all, log_every: log_every} - | Enum.reverse(callbacks) - ] - - jit_opts = [compiler: compiler] ++ opts - train_state = Nx.Defn.jit(init_fn, [], jit_opts) - - train_state = - case apply_callback(callbacks, train_state, jit_opts, :before_train) do - {:cont, train_state} -> - Enum.reduce_while(1..epochs, train_state, fn epoch, train_state -> - case apply_callback(callbacks, train_state, jit_opts, :before_epoch) do - {:cont, train_state} -> - {time, train_state} = - :timer.tc(&train_epoch/6, [ - step_fn, - train_state, - inputs, - targets, - callbacks, - jit_opts - ]) - - zero_metrics = Map.new(train_state[:metrics], fn {k, _} -> {k, 0.0} end) - - case apply_callback( - callbacks, - Map.put(train_state, :time, time), - jit_opts, - :after_epoch - ) do - {:cont, train_state} -> - train_state = %{ - Map.delete(train_state, :time) - | metrics: zero_metrics, - epoch: epoch, - epoch_step: 0, - epoch_loss: 0.0 - } - - {:cont, train_state} - - {:halt, train_state} -> - {:halt, train_state} - end - - {:halt, train_state} -> - {:halt, train_state} - end - end) - - {:halt, train_state} -> - train_state - end - - {_, train_state} = apply_callback(callbacks, train_state, jit_opts, :after_train) - - train_state - end - - ## Helpers - - defp train_epoch(step_fn, train_state, inputs, targets, callbacks, opts) do - dataset = Stream.zip(inputs, targets) - - Enum.reduce_while(dataset, train_state, fn {inp, tar}, train_state -> - case apply_callback(callbacks, train_state, opts, :before_batch) do - {:cont, train_state} -> - train_state = Nx.Defn.jit(step_fn, [train_state, inp, tar], opts) - apply_callback(callbacks, train_state, opts, :after_batch) - - {:halt, train_state} -> - {:halt, train_state} - end - end) - end - - defp apply_callback([], train_state, _, _), do: {:cont, train_state} - - defp apply_callback(callbacks, train_state, train_opts, event) do - result = - Enum.reduce_while(callbacks, train_state, fn - {callback, :all, opts}, train_state -> - case apply(callback, [train_state, event, opts ++ train_opts]) do - {:halt, acc} -> - {:halt, {:stopped, acc}} - - {:cont, acc} -> - {:cont, acc} - - other -> - raise "invalid return from callback #{inspect(other)}" - end - - {callback, on_event, opts}, train_state -> - if on_event == event do - case apply(callback, [train_state, event, opts ++ train_opts]) do - {:halt, acc} -> - {:halt, {:stopped, acc}} - - {:cont, acc} -> - {:cont, acc} - - other -> - raise "invalid return from callback #{inspect(other)}" - end - else - {:cont, train_state} - end - end) - - case result do - {:stopped, acc} -> - {:halt, acc} - - acc -> - {:cont, acc} - end - end -end diff --git a/lib/axon/training/callbacks.ex b/lib/axon/training/callbacks.ex deleted file mode 100644 index 0ab69bbb..00000000 --- 
a/lib/axon/training/callbacks.ex +++ /dev/null @@ -1,97 +0,0 @@ -defmodule Axon.Training.Callbacks do - @moduledoc """ - Axon training callbacks. - """ - - @doc """ - Standard IO Logger callback. - - Logs training results to standard out. - """ - def standard_io_logger(train_state, :before_train, opts) do - epochs = opts[:epochs] - metrics = Map.keys(train_state[:metrics]) - - IO.puts("Training model for #{epochs} epochs") - IO.puts("Metrics: #{inspect(metrics)}") - - {:cont, train_state} - end - - def standard_io_logger(train_state, :after_batch, opts) do - log_every = opts[:log_every] - - case log_every do - :none -> - :ok - - :every -> - log_batch( - train_state[:epoch], - train_state[:epoch_step], - train_state[:epoch_loss], - train_state[:metrics] - ) - - log_every when is_integer(log_every) -> - if Nx.remainder(train_state[:epoch_step], log_every) == Nx.tensor(0) do - log_batch( - train_state[:epoch], - train_state[:epoch_step], - train_state[:epoch_loss], - train_state[:metrics] - ) - end - end - - {:cont, train_state} - end - - def standard_io_logger(train_state, :after_epoch, _opts) do - epoch = Nx.to_scalar(train_state[:epoch]) - # Should this really be a part of train state, maybe an extra metadata argument? - time = train_state[:time] - epoch_loss = train_state[:epoch_loss] - - IO.puts("\n") - IO.puts("Epoch #{epoch + 1} time: #{time / 1_000_000}s") - IO.puts("Epoch #{epoch + 1} loss: #{:io_lib.format("~.5f", [Nx.to_scalar(epoch_loss)])}") - - train_state[:metrics] - |> Enum.each(fn {k, v} -> - IO.puts( - "Epoch #{epoch + 1} #{Atom.to_string(k)}: #{:io_lib.format("~.5f", [Nx.to_scalar(v)])}" - ) - end) - - IO.puts("\n") - - {:cont, train_state} - end - - def standard_io_logger(train_state, :after_train, _opts) do - IO.puts("Training finished") - {:cont, train_state} - end - - def standard_io_logger(train_state, _, _opts), do: {:cont, train_state} - - defp log_batch(epoch, step, loss, metrics) do - metrics = - metrics - |> Enum.map(fn {k, v} -> - "Average #{Atom.to_string(k)}: #{:io_lib.format("~.5f", [Nx.to_scalar(v)])}" - end) - - metrics = - Enum.join( - ["Average Loss: #{:io_lib.format("~.5f", [Nx.to_scalar(loss)])}" | metrics], - " - " - ) - - IO.write( - "\rEpoch #{Nx.to_scalar(epoch) + 1}, batch #{Nx.to_scalar(step)} - " <> - "#{metrics}" - ) - end -end diff --git a/lib/axon/training/step.ex b/lib/axon/training/step.ex deleted file mode 100644 index 8078b4b7..00000000 --- a/lib/axon/training/step.ex +++ /dev/null @@ -1,19 +0,0 @@ -defmodule Axon.Training.Step do - @moduledoc false - - # Training step which controls the Axon training loop - # Functions in the Training module work directly on this struct - - # Fields - # - # :init - Initialization of training state, training loops are modeled - # as a reduction over the training data with the training state as the - # accumulator. This will initialize the state of the accumulator - # - # :step - Training step. Performs a single step, updating the training - # state - # - # :callbacks - List of training callbacks. 
Performed at various times - # throughout training - defstruct [:init, :step, :callbacks] -end diff --git a/lib/axon/updates.ex b/lib/axon/updates.ex index 9b497ff8..5c749c35 100644 --- a/lib/axon/updates.ex +++ b/lib/axon/updates.ex @@ -753,7 +753,9 @@ defmodule Axon.Updates do """ defn apply_updates(params, updates) do transform({params, updates}, fn {params, updates} -> - deep_merge(params, updates, fn x, u -> Nx.add(x, u) end) + deep_merge(params, updates, fn x, u -> + Nx.add(x, Nx.as_type(u, Nx.type(x))) + end) end) end diff --git a/mix.lock b/mix.lock index 6521e173..68bb4d20 100644 --- a/mix.lock +++ b/mix.lock @@ -5,6 +5,6 @@ "makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, "nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"}, - "nx": {:git, "https://github.com/elixir-nx/nx.git", "f4aea5cddfa413c39fbf1ec8b70d10800470f588", [sparse: "nx"]}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "ff12877d5a9e46437e22706333bf82a0753cda3c", [sparse: "nx"]}, "table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"}, } diff --git a/test/loop_test.exs b/test/loop_test.exs new file mode 100644 index 00000000..063010e4 --- /dev/null +++ b/test/loop_test.exs @@ -0,0 +1,192 @@ +defmodule Axon.LoopTest do + use ExUnit.Case, async: true + + alias Axon.Loop + alias Axon.Loop.State + + describe "factories" do + test "loop/3 creates a basic loop with defaults" do + step_fn = fn _, _ -> 1 end + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.loop(step_fn) + + assert init_fn.() == %{} + assert update_fn.({}, %{}) == 1 + assert transform.(%{}) == %{} + end + + test "trainer/3 returns a supervised training loop with basic case" do + model = Axon.input({nil, 1}) + + valid_axon_losses = [ + :binary_cross_entropy, + :categorical_cross_entropy, + :categorical_hinge, + :hinge, + :kl_divergence, + :log_cosh, + :mean_absolute_error, + :mean_squared_error, + :poisson, + :soft_margin + ] + + valid_axon_optimizers = + Axon.Optimizers.__info__(:functions) + |> Enum.map(fn {k, _} -> k end) + |> Enum.uniq() + + for loss <- valid_axon_losses do + for optimizer <- valid_axon_optimizers do + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.trainer(model, loss, optimizer) + + assert %{model_state: %{}} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate} + + assert %{model_state: %{}, y_true: tar, y_pred: pred} = + Nx.Defn.jit(update_fn, [{Nx.tensor([[1]]), Nx.tensor([[1]])}, pstate]) + + assert tar == Nx.tensor([[1]]) + assert pred == Nx.tensor([[1]]) + + assert transform.(state) == %{} + end + end + end + + test "trainer/3 returns a supervised training loop with 
custom loss" do + model = Axon.input({nil, 1}) + custom_loss_fn = fn _, _ -> Nx.tensor(5.0, backend: Nx.Defn.Expr) end + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.trainer(model, custom_loss_fn, :adam) + + assert %{model_state: %{}} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate} + + assert %{model_state: %{}, y_true: tar, y_pred: pred, loss: loss} = + Nx.Defn.jit(update_fn, [{Nx.tensor([[1]]), Nx.tensor([[1]])}, pstate]) + + assert tar == Nx.tensor([[1]]) + assert pred == Nx.tensor([[1]]) + assert loss == Nx.tensor(5.0) + + assert transform.(state) == %{} + end + + test "trainer/3 returns a supervised training loop with custom optimizer" do + model = Axon.input({nil, 1}) + optimizer = Axon.Optimizers.rmsprop(1.0e-3) + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.trainer(model, :mean_squared_error, optimizer) + + assert %{model_state: %{}} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate} + + assert %{model_state: %{}, y_true: tar, y_pred: pred} = + Nx.Defn.jit(update_fn, [{Nx.tensor([[1]]), Nx.tensor([[1]])}, pstate]) + + assert tar == Nx.tensor([[1]]) + assert pred == Nx.tensor([[1]]) + + assert transform.(state) == %{} + end + + test "trainer/3 returns a supervised training loop with custom model" do + model = Axon.input({nil, 1}) |> Axon.compile() + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.trainer(model, :mean_squared_error, :adam) + + assert %{model_state: %{}} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate} + + assert %{model_state: %{}, y_true: tar, y_pred: pred} = + Nx.Defn.jit(update_fn, [{Nx.tensor([[1]]), Nx.tensor([[1]])}, pstate]) + + assert tar == Nx.tensor([[1]]) + assert pred == Nx.tensor([[1]]) + + assert transform.(state) == %{} + end + + test "trainer/3 returns a supervised training loop with multi-loss" do + model = {Axon.input({nil, 1}), Axon.input({nil, 1})} + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam) + + assert %{model_state: %{}} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate} + + assert %{model_state: %{}, y_true: tar, y_pred: pred, loss: loss} = + Nx.Defn.jit(update_fn, [ + {{Nx.tensor([[1]]), Nx.tensor([[1]])}, {Nx.tensor([[2]]), Nx.tensor([[2]])}}, + pstate + ]) + + assert tar == {Nx.tensor([[2]]), Nx.tensor([[2]])} + assert pred == {Nx.tensor([[1]]), Nx.tensor([[1]])} + assert loss == Nx.tensor(1.0) + + assert transform.(state) == %{} + end + + test "trainer/3 raises on bad inputs" do + assert_raise ArgumentError, ~r/Invalid/, fn -> + Axon.Loop.trainer(:foo, :mean_squared_error, :adam) + end + + assert_raise ArgumentError, ~r/Invalid/, fn -> + Axon.Loop.trainer(Axon.input({nil, 1}), :foo, :adam) + end + + assert_raise ArgumentError, ~r/Invalid/, fn -> + Axon.Loop.trainer(Axon.input({nil, 1}), :mean_squared_error, :foo) + end + end + + test "evaluator/3 returns a supervised evaluator loop" do + model = Axon.input({nil, 1}) + model_state = %{} + + assert %Loop{init: init_fn, step: update_fn, output_transform: transform} = + Loop.evaluator(model, model_state) + + assert %{y_true: _, y_pred: _} = pstate = Nx.Defn.jit(init_fn, []) + + state = %State{step_state: pstate, metrics: %{"my_metric" => {}}} + + assert %{y_true: tar, y_pred: pred} = + Nx.Defn.jit(update_fn, [{Nx.tensor([[1]]), Nx.tensor([[2]])}, pstate]) + + assert 
tar == Nx.tensor([[2]]) + assert pred == Nx.tensor([[1]]) + + assert transform.(state) == %{"my_metric" => {}} + end + end + + describe "looping" do + test "returns initial state with epochs 0" do + step_fn = fn _, _ -> 1 end + + state = + step_fn + |> Loop.loop() + |> Loop.run([], epochs: 0) + + assert %State{epoch: 0, iteration: 0, times: %{}, metrics: %{}, step_state: pstate} = state + + assert pstate == %{} + end + end +end diff --git a/test/mixed_precision_test.exs b/test/mixed_precision_test.exs index bf5b461b..51a86c74 100644 --- a/test/mixed_precision_test.exs +++ b/test/mixed_precision_test.exs @@ -3,7 +3,7 @@ defmodule MixedPrecisionTest do alias Axon.MixedPrecision.Policy alias Axon.MixedPrecision, as: AMP - alias Axon.Training.Step + alias Axon.Loop describe "creation and application" do test "create policy" do @@ -49,14 +49,13 @@ defmodule MixedPrecisionTest do mp_model = AMP.apply_policy(model, policy, except: [:batch_norm]) - %Step{init: init_fn, step: step_fn} = - Axon.Training.step(mp_model, :binary_cross_entropy, Axon.Optimizers.sgd(0.01)) + %Loop{init: init_fn, step: step_fn} = + Axon.Loop.trainer(mp_model, :binary_cross_entropy, Axon.Optimizers.sgd(0.01)) - state = init_fn.() + pstate = + Nx.Defn.jit(step_fn, [{Nx.random_uniform({1, 32}), Nx.random_uniform({1, 1})}, init_fn.()]) - state = Nx.Defn.jit(step_fn, [state, Nx.random_uniform({1, 32}), Nx.random_uniform({1, 1})]) - - params = state[:params] + params = pstate[:model_state] assert Nx.type(params["dense1"]["kernel"]) == {:bf, 16} assert Nx.type(params["dense1"]["bias"]) == {:bf, 16}