Skip to content

Commit

Permalink
New training API (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Oct 19, 2021
1 parent 8916210 commit 1a215c7
Show file tree
Hide file tree
Showing 19 changed files with 1,796 additions and 772 deletions.
93 changes: 62 additions & 31 deletions examples/cifar10.exs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
Mix.install([
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"},
{:exla, github: "elixir-nx/exla", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.1"},
{:scidata, "~> 0.1.1"}
])

# Configure default platform with accelerator precedence as tpu > cuda > rocm > host
EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])

defmodule Cifar do
require Axon
alias Axon.Loop.State

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 3, 32, 32})
|> Nx.divide(255.0)
|> Nx.to_batched_list(32)
|> Enum.split(1500)
end

defp transform_labels({bin, type, _}) do
Expand All @@ -22,15 +27,7 @@ defmodule Cifar do
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

defp view_images(images, {start_index, len}) do
images
|> hd()
|> Nx.slice_axis(start_index, len, 0)
|> Nx.mean(axes: [1], keep_axes: true)
|> Nx.to_heatmap()
|> IO.inspect
|> Enum.split(1500)
end

defp build_model(input_shape) do
Expand All @@ -39,43 +36,77 @@ defmodule Cifar do
|> Axon.batch_norm()
|> Axon.max_pool(kernel_size: {2, 2})
|> Axon.conv(64, kernel_size: {3, 3}, activation: :relu)
|> Axon.spatial_dropout()
|> Axon.batch_norm()
|> Axon.max_pool(kernel_size: {2, 2})
|> Axon.conv(32, kernel_size: {3, 3}, activation: :relu)
|> Axon.batch_norm()
|> Axon.flatten()
|> Axon.dense(64, activation: :relu)
|> Axon.dropout()
|> Axon.dropout(rate: 0.5)
|> Axon.dense(10, activation: :softmax)
end

defp train_model(model, {train_images, train_labels}, epochs) do
defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
mode
) do
loss =
case mode do
:train ->
%{loss: loss} = pstate
"Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}"

:test ->
""
end

metrics =
metrics
|> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end)
|> Enum.join(" ")

IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}")

{:continue, state}
end

defp train_model(model, train_images, train_labels, epochs) do
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50)
|> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA)
end

defp test_model(model, model_state, test_images, test_labels) do
model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.sgd(0.01), metrics: [:accuracy])
|> Axon.Training.train(train_images, train_labels, epochs: epochs, compiler: EXLA)
|> Nx.backend_transfer()
|> Axon.Loop.evaluator(model_state)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :test), every: 50)
|> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA)
end

def run do
{train_images, train_labels} = Scidata.CIFAR10.download(transform_images: &transform_images/1, transform_labels: &transform_labels/1)
{images, labels} =
Scidata.CIFAR10.download(
transform_images: &transform_images/1,
transform_labels: &transform_labels/1
)

{train_images, test_images} = images
{train_labels, test_labels} = labels

view_images(train_images, {0, 1})
model = build_model({nil, 3, 32, 32}) |> IO.inspect()

model = build_model({nil, 3, 32, 32}) |> IO.inspect
IO.write("\n\n Training Model \n\n")

final_training_state =
model_state =
model
|> train_model({train_images, train_labels}, 20)
|> IO.inspect()
|> train_model(train_images, train_labels, 10)

test_images = train_images |> hd() |> Nx.slice_axis(10, 3, 0)
view_images(train_images, {10, 3})
IO.write("\n\n Testing Model \n\n")

model
|> Axon.predict(final_training_state[:params], test_images)
|> Nx.argmax(axis: -1)
|> IO.inspect
test_model(model, model_state, test_images, test_labels)

IO.write("\n\n")
end
end

Expand Down
40 changes: 35 additions & 5 deletions examples/fashionmnist_autoencoder.exs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
Mix.install([
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"},
{:exla, github: "elixir-nx/exla", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.1"},
])

# Configure default platform with accelerator precedence as tpu > cuda > rocm > host
EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])

defmodule Fashionmist do
require Axon
alias Axon.Loop.State

defmodule Autoencoder do
defp encoder(x, latent_dim) do
Expand Down Expand Up @@ -36,18 +40,44 @@ defmodule Fashionmist do
|> Nx.to_batched_list(32)
end

defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
mode
) do
loss =
case mode do
:train ->
%{loss: loss} = pstate
"Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}"

:test ->
""
end

metrics =
metrics
|> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end)
|> Enum.join(" ")

IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}")

{:continue, state}
end

defp train_model(model, train_images, epochs) do
model
|> Axon.Training.step(:mean_squared_error, Axon.Optimizers.adam(0.01), metrics: [:mean_absolute_error])
|> Axon.Training.train(train_images, train_images, epochs: epochs, compiler: EXLA)
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.metric(:mean_absolute_error, "Error")
|> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50)
|> Axon.Loop.run(Stream.zip(train_images, train_images), epochs: epochs, compiler: EXLA)
end

def run do
{train_images, _} = Scidata.FashionMNIST.download(transform_images: &transform_images/1)

model = Autoencoder.build_model({nil, 1, 28, 28}, 64) |> IO.inspect

final_training_state = train_model(model, train_images, 5)
model_state = train_model(model, train_images, 5)

sample_image =
train_images
Expand All @@ -58,7 +88,7 @@ defmodule Fashionmist do
sample_image |> Nx.to_heatmap() |> IO.inspect

model
|> Axon.predict(final_training_state[:params], sample_image, compiler: EXLA)
|> Axon.predict(model_state, sample_image, compiler: EXLA)
|> Nx.to_heatmap()
|> IO.inspect()
end
Expand Down
90 changes: 65 additions & 25 deletions examples/mnist.exs
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
Mix.install([
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon", branch: "main"},
{:exla, github: "elixir-nx/exla", sparse: "exla"},
{:axon, "~> 0.1.0-dev", github: "elixir-nx/axon"},
{:exla, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.1.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true},
{:scidata, "~> 0.1.1"},
{:scidata, "~> 0.1.1"}
])

# Configure default platform with accelerator precedence as tpu > cuda > rocm > host
EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])

defmodule Mnist do
require Axon

alias Axon.Loop.State

defp transform_images({bin, type, shape}) do
bin
|> Nx.from_binary(type)
|> Nx.reshape({elem(shape, 0), 784})
|> Nx.divide(255.0)
|> Nx.to_batched_list(32)
# Test split
|> Enum.split(1750)
end

defp transform_labels({bin, type, _}) do
Expand All @@ -22,15 +29,8 @@ defmodule Mnist do
|> Nx.new_axis(-1)
|> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
|> Nx.to_batched_list(32)
end

defp view_images(images, {start_index, len}) do
images
|> hd()
|> Nx.slice_axis(start_index, len, 0)
|> Nx.reshape({:auto, 28, 28})
|> Nx.to_heatmap()
|> IO.inspect
# Test split
|> Enum.split(1750)
end

defp build_model(input_shape) do
Expand All @@ -40,30 +40,70 @@ defmodule Mnist do
|> Axon.dense(10, activation: :softmax)
end

defp train_model(model, {train_images, train_labels}, epochs) do
defp log_metrics(
%State{epoch: epoch, iteration: iter, metrics: metrics, process_state: pstate} = state,
mode
) do
loss =
case mode do
:train ->
%{loss: loss} = pstate
"Loss: #{:io_lib.format('~.5f', [Nx.to_scalar(loss)])}"

:test ->
""
end

metrics =
metrics
|> Enum.map(fn {k, v} -> "#{k}: #{:io_lib.format('~.5f', [Nx.to_scalar(v)])}" end)
|> Enum.join(" ")

IO.write("\rEpoch: #{Nx.to_scalar(epoch)}, Batch: #{Nx.to_scalar(iter)}, #{loss} #{metrics}")

{:continue, state}
end

defp train_model(model, train_images, train_labels, epochs) do
model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005), metrics: [:accuracy])
|> Axon.Training.train(train_images, train_labels, epochs: epochs, compiler: EXLA, log_every: 100)
|> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :train), every: 50)
|> Axon.Loop.run(Stream.zip(train_images, train_labels), epochs: epochs, compiler: EXLA)
end

defp test_model(model, model_state, test_images, test_labels) do
model
|> Axon.Loop.evaluator(model_state)
|> Axon.Loop.metric(:accuracy, "Accuracy")
|> Axon.Loop.handle(:iteration_completed, &log_metrics(&1, :test), every: 50)
|> Axon.Loop.run(Stream.zip(test_images, test_labels), compiler: EXLA)
end

def run do
{train_images, train_labels} = Scidata.MNIST.download(transform_images: &transform_images/1, transform_labels: &transform_labels/1)
{images, labels} =
Scidata.MNIST.download(
transform_images: &transform_images/1,
transform_labels: &transform_labels/1
)

{train_images, test_images} = images
{train_labels, test_labels} = labels

view_images(train_images, {0, 1})
model = build_model({nil, 784}) |> IO.inspect()

model = build_model({nil, 784}) |> IO.inspect
IO.write("\n\n Training Model \n\n")

final_training_state =
model_state =
model
|> train_model({train_images, train_labels}, 10)
|> train_model(train_images, train_labels, 5)

test_images = train_images |> hd() |> Nx.slice_axis(10, 3, 0)
view_images(train_images, {10, 3})
IO.write("\n\n Testing Model \n\n")

model
|> Axon.predict(final_training_state[:params], test_images)
|> Nx.argmax(axis: -1)
|> IO.inspect
|> test_model(model_state, test_images, test_labels)

IO.write("\n\n")
end
end

Expand Down
Loading

0 comments on commit 1a215c7

Please sign in to comment.