
New training API #132

Merged
seanmor5 merged 21 commits into main from sm-new-training-engine on Oct 19, 2021
Conversation

seanmor5
Contributor

This PR implements a new training API modeled after PyTorch Ignite, which is a little more thought out and extensible than the previous API. When finished, this PR will resolve #22, #78, and #80, and provide what I think will be an easier path forward for #25, #79, and #101.

The basic idea is that Axon.Training implements conveniences and functions for running loops or reductions in the same way PyTorch Ignite does. A "loop" is a struct which controls the inner workings of the loop; it currently contains the fields process, metrics, and handlers. process is a "process" struct which represents the loop's processing function. A loop processing function accepts the loop state and some data, and returns an updated process state. The process struct contains an init and a step function, for initializing and updating the process state respectively.

metrics is a map of metrics to attach to the process state. This means you can apply metrics on predictions, parameters, optimizer state, etc.

handlers is a map of events and their specific handlers. Handlers are meant to be side-effecting, so they don't alter training state in any way. This might change in the future.
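For concreteness, here's a minimal sketch of these shapes. The struct and field names come from the description above; the init/step bodies, the state shape, and the event name are hypothetical placeholders:

# Minimal sketch; init/step bodies, state shape, and the :completed
# event name are hypothetical placeholders, not the final API.
init_fn = fn -> %{count: 0} end
step_fn = fn state, _data -> %{state | count: state.count + 1} end

process = %Axon.Training.Process{init: init_fn, step: step_fn}

# handlers map events to side-effecting functions of the loop state
loop = %Axon.Training.Loop{
  process: process,
  metrics: %{},
  handlers: %{completed: [fn state -> IO.inspect(state) end]}
}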

The current API exposes a run function that runs a given loop. It doesn't make any assumptions about whether we're doing training, validation, testing, etc. Ignite provides examples of some of the other cool things you can do with their Engines (like processing datasets and more).

There's also a step function which creates a loop from a model, loss, and optimizer. I plan on adding additional loop factories in the future. Ideally, we'd rename this to something like supervised_train_step or just train_step, and then we could have validation_step, etc.

Training doesn't change all that much:

model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3))
|> Axon.Training.metric(:accuracy, "Accuracy")
|> Axon.Training.run(train_data) # assuming train_data is a stream of {input, target} batches

But now we can do things like:

def test_model(model, model_state, {test_images, test_labels}) do
    # init_fn was left undefined in the original snippet; pure inference
    # needs no extra process state, so a minimal placeholder suffices
    init_fn = fn -> %{} end

    process_fn = fn _, {inp, _} ->
      preds = Axon.predict(model, model_state, inp)
      %{predictions: preds, loss: Nx.tensor(0.0, backend: Nx.Defn.Expr)}
    end

    process = %Axon.Training.Process{init: init_fn, step: process_fn}
    loop = %Axon.Training.Loop{process: process} |> Axon.Training.metric(:accuracy, "Accuracy")
    Axon.Training.run(loop, Stream.zip(test_images, test_labels), compiler: EXLA)
end

The above example can be cleaned up even further with some additional factories and refining of the API.

Now the Axon.Training API is less of a training API and more of a loop API to be used in training. I'm not sure it makes sense to keep calling it Axon.Training because it could theoretically be used outside of training, but I'm also not sure it really matters.

This is still WIP as there are a lot of things I want to play with/add. I also plan on adding a number of integration training tests with this new API.

@arpieb
Contributor

arpieb commented Oct 13, 2021

Now the Axon.Training API is less of a training API and more of a loop API to be used in training. I'm not sure it makes sense to keep calling it Axon.Training because it could theoretically be used outside of training, but I'm also not sure it really matters.

Names are important. If the API is shifting to something more generalized for executing a model in a context other than just training, it should probably reflect that to prevent confusion. If I'm writing a validation or scoring function and I'm calling something called Axon.Training, it just doesn't smell right. That being said, I'm at a loss atm as to what might be a better name for that module...

I do like the separation of concerns re step, metrics and handlers - it feels more composable.

Naming steps is going to be challenging IMO since you're generalizing the "loop": with a training_step and a validation_step, what would you use for a plain inference step? Maybe there's a training_step that's understood to update the model params and another that is strictly inference, like predict_step? When you get down to it, a validation step is simply an inference step with a dataset separate from the training set.

Ok, I've rambled enough... I'll think more re naming things and wait for feedback from others as well.

@seanmor5
Contributor Author

Naming steps is going to be challenging IMO since you're generalizing the "loop": with a training_step and a validation_step, what would you use for a plain inference step? Maybe there's a training_step that's understood to update the model params and another that is strictly inference, like predict_step? When you get down to it, a validation step is simply an inference step with a dataset separate from the training set.

What I have in mind here is to provide factories for common processing functions in the loop, e.g. train_step and eval_step (which can be used for validation/testing; the model runs in inference mode and we return predictions which we attach metrics to). There would still be factories for creating processes for anything, such as process or just a generic step. Ignite wraps a plain process function in an Engine object:

from ignite.engine import Engine

def process_function(engine, batch):
    ...  # do something with the batch and return the step output

engine = Engine(process_function)
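Translated back to Axon, a hypothetical eval_step factory would collapse the earlier test_model example into something like this (name and arity are a sketch, not a final API):

# Hypothetical sketch: eval_step as a loop factory mirroring
# the existing Axon.Training.step
model
|> Axon.Training.eval_step(model_state)
|> Axon.Training.metric(:accuracy, "Accuracy")
|> Axon.Training.run(Stream.zip(test_images, test_labels), compiler: EXLA)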

@josevalim
Contributor

josevalim commented Oct 13, 2021

This is looking solid!

@arpieb @seanmor5 regarding the name, what about Axon.Loop? Then you can use Axon.Loop.training to generate a training loop. Or you can call it Axon.Step with Axon.Step.training.
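For illustration, that suggestion would read roughly like this (hypothetical naming only):

# Hypothetical: the Axon.Loop naming proposed above
model
|> Axon.Loop.training(:categorical_cross_entropy, Axon.Optimizers.adam(5.0e-3))
|> Axon.Loop.run(train_data)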

seanmor5 marked this pull request as ready for review October 18, 2021 15:12
Comment on lines 8 to 21
# Configure default platform with accelerator precedence as tpu > cuda > rocm > host
case EXLA.Client.get_supported_platforms() do
%{'TPU' => _} ->
Application.put_env(:exla, :clients, default: [platform: :tpu])

%{'CUDA' => _} ->
Application.put_env(:exla, :clients, default: [platform: :cuda])

%{'ROCM' => _} ->
Application.put_env(:exla, :clients, default: [platform: :rocm])

%{'Host' => _} ->
Application.put_env(:exla, :clients, default: [platform: :host])
end
Contributor

Doesn't look very user-friendly to me. Would call it EXLA.Client.SET_supported_platforms and put it where it belongs.

Contributor Author

EXLA.Client.get_supported_platforms() is a wrapper around a TensorFlow function which queries the system for info about which accelerators are available. This is just a quick way to make sure the script uses a GPU if one is available, so that running elixir examples/cifar10.exs will find and use the GPU by default. Ideally users will already have the application configured when creating their own scripts.

Contributor

Well, ideally, I'd have a single function to configure the app, instead of filling every script with the same wall of text. Also, will I have to modify every project if the set of supported platforms changes?

Contributor

I think we should have a function that returns the default platform in order of preference. Something like this:

EXLA.Client.get_preferred_platform([:tpu, :cuda, :rocm, :host])

We can also have one that accepts a client name:

EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])

WDYT?
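As a rough illustration, the setter could be built on the get_supported_platforms() call shown above, which keys platforms by charlist name. This is a hypothetical sketch, not part of EXLA's actual API:

defmodule PlatformConfig do
  # Hypothetical helper: configure the given client with the first
  # preferred platform that the host actually supports
  @names %{tpu: 'TPU', cuda: 'CUDA', rocm: 'ROCM', host: 'Host'}

  def set_preferred_platform(client, preferences) do
    supported = EXLA.Client.get_supported_platforms()
    platform = Enum.find(preferences, &Map.has_key?(supported, @names[&1]))
    Application.put_env(:exla, :clients, [{client, [platform: platform]}])
  end
end

# Usage: PlatformConfig.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])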

Contributor

It would also be nice if EXLA.Client.get_supported_platforms() returned atoms?

Contributor Author

Oh, sorry, to be clear: to configure your apps, you would either set config in config/runtime.exs or use the Application.put_env approach in your scripts. This is not intended to be a public API. This is under the assumption that the example scripts are run using elixir ..., so there is no project configuration set. In that case, this ensures any available accelerator is used.

As long as your projects use one of the maintained platforms (host, cuda, rocm, or tpu), you won't need to worry about changes in supported platforms.

Contributor Author

EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])

This is what I would go with.

Contributor

@josevalim
If you're asking me, I'm not an expert. :)

Contributor

Oh, sorry, to be clear: to configure your apps, you would either set config in config/runtime.exs or use the Application.put_env approach in your scripts.

Sorry, it's just that I'm not used to the idea that I have to set configs for a script to work.

This is not intended to be a public API. This is under the assumption that the example scripts are run using elixir ..., so there is no project configuration set. In that case, this ensures any available accelerator is used.

There's no place for "not a public API" in an ordinary cifar10 example.
And how does adding a shortcut for users harm anybody?

seanmor5 merged commit 1a215c7 into main Oct 19, 2021
seanmor5 deleted the sm-new-training-engine branch October 19, 2021 15:10