New training API #132
Conversation
Names are important. If the API is shifting to something more generalized for executing a model in a context other than just training, it should probably reflect that to prevent confusion: if I'm writing a validation or scoring function, I don't want to be calling something named after training. I do like the separation of concerns. Naming steps is going to be challenging imo, since you're generalizing the "loop". Ok, I've rambled enough... I'll think more re naming things and wait for feedback from others as well.
What I have in mind here is to provide factories for common processing functions in the loop, e.g.:

```python
def process_function(...):
    ...do something...

engine = Engine(process_function)
```
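In Elixir, the same idea might look roughly like this (a sketch; `Axon.Training.loop/1` is a hypothetical factory name used only for illustration):

```elixir
# Sketch: any function from (state, batch) -> state can drive the loop;
# a factory would wrap it into the loop/process struct.
process_fn = fn state, _batch ->
  # ...do something, e.g. a gradient update...
  state
end

loop = Axon.Training.loop(process_fn)
```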
examples/cifar10.exs (outdated)
```elixir
# Configure default platform with accelerator precedence as tpu > cuda > rocm > host
case EXLA.Client.get_supported_platforms() do
  %{'TPU' => _} ->
    Application.put_env(:exla, :clients, default: [platform: :tpu])

  %{'CUDA' => _} ->
    Application.put_env(:exla, :clients, default: [platform: :cuda])

  %{'ROCM' => _} ->
    Application.put_env(:exla, :clients, default: [platform: :rocm])

  %{'Host' => _} ->
    Application.put_env(:exla, :clients, default: [platform: :host])
end
```
Doesn't look very user-friendly to me. Would call it `EXLA.Client.set_supported_platforms` and put it where it belongs.
`EXLA.Client.get_supported_platforms()` is a wrapper around a TensorFlow function which queries the system for info about what accelerators are available. This is just a quick way to make sure the script uses a GPU if one is available, so when running `elixir examples/cifar10.exs` it will find and use the GPU by default. Ideally, users will already have the application configured when creating their own scripts.
Well, ideally, I'd have a single function to configure the app, instead of filling every one of my scripts with a wall of the same text. Also, will I have to modify every project if the supported platforms change?
I think we should have a function that returns the default platform in order of preference. Something like this:

```elixir
EXLA.Client.get_preferred_platform([:tpu, :cuda, :rocm, :host])
```

We can also have one that accepts a client name:

```elixir
EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])
```

WDYT?
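A sketch of what `set_preferred_platform/2` could do, assuming `get_supported_platforms/0` returns a map keyed by platform (atom keys assumed here for readability):

```elixir
defmodule EXLAHelpers do
  # Sketch only: configure the first platform from the preference
  # list that the host actually supports.
  def set_preferred_platform(client_name, preferences) do
    supported = EXLA.Client.get_supported_platforms()

    Enum.find_value(preferences, fn platform ->
      if Map.has_key?(supported, platform) do
        Application.put_env(:exla, :clients, [{client_name, [platform: platform]}])
        platform
      end
    end)
  end
end
```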
It would also be nice if `EXLA.Client.get_supported_platforms()` returned atoms?
Oh sorry, to be clear: to configure your apps, you would either set config in `config/runtime.exs` or use the `Application.put_env` approach in your scripts. This is not intended to be a public API. This is under the assumption that the example scripts are run using `elixir ...`, so there is no project configuration set. In that case, this ensures any available accelerator is used.

As long as your projects use one of the maintained platforms (`host`, `cuda`, `rocm`, or `tpu`), you won't need to worry about changes in supported platforms.
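For example, a one-time entry in `config/runtime.exs` (a sketch; `:cuda` is just for illustration):

```elixir
import Config

# Sketch: project-level equivalent of the Application.put_env call
# above, set once instead of in every script.
config :exla, :clients, default: [platform: :cuda]
```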
```elixir
EXLA.Client.set_preferred_platform(:default, [:tpu, :cuda, :rocm, :host])
```

This is what I would go with.
@josevalim
If you're asking me, I'm not an expert. :)
> Oh sorry, to be clear: to configure your apps, you would either set config in `config/runtime.exs` or use the `Application.put_env` approach in your scripts.

Sorry, it's just that I'm not used to the idea that I have to set configs for a script to work.

> This is not intended to be a public API. This is under the assumption that the example scripts are run using `elixir ...`, so there is no project configuration set. In that case, this ensures any available accelerator is used.

There's no place for "not a public API" in an ordinary cifar10 example. And how does adding a shortcut for a user harm anybody?
This PR implements a new training API modeled after PyTorch Ignite, which is a little more thought out and extensible than the previous API. When finished, this PR will resolve #22, #78, and #80, and provide at least what I think will be an easier path forward for #25, #79, and #101.
The basic idea is that `Axon.Training` implements conveniences and functions for running loops or reductions in the same way that PyTorch Ignite does. A "loop" is a struct which controls the inner workings of the loop; currently it contains the fields `process`, `metrics`, and `handlers`:

- `process` is a "process" struct which represents the loop processing function. A loop processing function accepts the loop state and some data, and returns an updated process state. The process struct contains an init and a step function for initializing and updating the process state, respectively.
- `metrics` is a map of metrics to attach to the process state. This means you can apply metrics to predictions, parameters, optimizer state, etc.
- `handlers` is a map of events and their specific handlers. Handlers are meant to be side-effecting, so they don't alter training state in any way. This might change in the future.

The current API exposes a `run` function that runs a given loop. It doesn't make any assumptions about whether we're doing training, validation, testing, etc. Ignite provides examples of some of the other cool things you can do with their Engines (like processing datasets and more).

There's also a `step` function which creates a loop from a model, loss, and optimizer. I plan on adding additional loop factories in the future. Ideally, we change the name of this to something like `supervised_train_step` or just `train_step`, and then we can have `validation_step`, etc.

Training doesn't change all that much.
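Basic supervised training still looks something like this (a sketch; the exact `run` options shown here, like `:epochs` and `:compiler`, are illustrative rather than final):

```elixir
# Sketch: build a supervised training loop from a model, loss, and
# optimizer, then run it over the training data.
model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.005))
|> Axon.Training.run(train_data, epochs: 10, compiler: EXLA)
```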
But now we can do things like attaching metrics and event handlers to the loop before running it.
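For example (a sketch; the `metric/2` and `handle/3` helper names are illustrative, not final):

```elixir
# Sketch: attach a metric and a side-effecting event handler to the
# loop, then run it. Handlers fire on loop events such as epoch end.
model
|> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adam(0.005))
|> Axon.Training.metric(:accuracy)
|> Axon.Training.handle(:epoch_completed, &IO.inspect(&1.metrics))
|> Axon.Training.run(train_data, epochs: 10, compiler: EXLA)
```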
The above example can be cleaned up even further with some additional factories and refining of the API.
Now the `Axon.Training` API is less of a training API and more of a loop API to be used in training. Not sure if it makes sense to continue to call it `Axon.Training` anymore, because it theoretically could be used outside of training, but not sure if it really matters.

This is still WIP, as there are a lot of things I want to play with/add. I also plan on adding a number of integration training tests with this new API.