Skip to content

Commit

Permalink
Add migx ep fp8 int4 (#78)
Browse files Browse the repository at this point in the history
* Add fp8 and int4 types in supported list for Onnxruntime EP

* Add support for int4 inputs

Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands

* Add flag to allow for fp8 quantization through Onnxruntime API

* Add fp8 quantization to the compile stage of the MIGraphX EP

Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphx API

* cleanup logging

* Cleanup and encapsulate quantization / compile functions

- Add additional flags for fp8 thats shared for int8

- Add lockout warning message when int8/fp8 used at the same time

* Run lintrunner pass

* Fix session options inputs + add better logging.

Previous runs using session options failed as we were missing pulling in inputs from the python interface. This plus additional logging allowed me to track what options were invoked via env and what were added during the start of an inference session

* Fix naming for save/load path varibles to be consistent with  enable.

* Print only env variables that are set as warnings

need this so the user knows there's any of the environment variables running in the background to ensure proper consistently between runs.

---------

Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
TedThemistokleous and Ted Themistokleous committed Jan 29, 2025
1 parent bb933d4 commit 3ae2923
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 106 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ typedef struct OrtTensorRTProviderOptions {
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
Expand Down
298 changes: 194 additions & 104 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@ namespace onnxruntime {

namespace migraphx_env_vars {
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE";
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH";
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";

}; // namespace migraphx_env_vars
Expand All @@ -43,6 +44,7 @@ struct MIGraphXFuncState {
std::mutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool fp8_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
Expand All @@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
~MIGraphXExecutionProvider();

void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info);
void get_flags_from_env();
void print_migraphx_ep_flags();

Status Sync() const override;

Status OnRunStart(const onnxruntime::RunOptions& run_options) override;
Expand Down Expand Up @@ -92,6 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
private:
MIGraphXExecutionProviderInfo info_;
bool fp16_enable_ = false;
bool fp8_enable_ = false;
bool int8_enable_ = false;
std::string int8_calibration_cache_name_;
bool int8_calibration_cache_available_ = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFp8Enable = "migx_fp8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
Expand Down Expand Up @@ -82,6 +83,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
Expand All @@ -100,6 +102,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
Expand All @@ -118,6 +121,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct MIGraphXExecutionProviderInfo {
std::string target_device;
OrtDevice::DeviceId device_id{0};
bool fp16_enable{false};
bool fp8_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.fp8_enable = options.migraphx_fp8_enable;
info.exhaustive_tune = options.migraphx_exhaustive_tune;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
Expand Down Expand Up @@ -109,6 +110,7 @@ struct MIGraphX_Provider : Provider {
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_fp8_enable = internal_options.fp8_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
Expand All @@ -880,6 +881,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_fp8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp8_enable = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
0,
0,
0,
0,
nullptr,
1,
"./compiled_model.mxr",
Expand Down

0 comments on commit 3ae2923

Please sign in to comment.