Skip to content

Commit

Permalink
use migraphx_provider_option contants
Browse files Browse the repository at this point in the history
  • Loading branch information
apwojcik committed Feb 3, 2025
1 parent 8c94c9d commit 90f06ee
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,6 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
Expand All @@ -871,7 +868,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "migraphx_fp16_enable") {
} else if (option.first == migraphx_provider_option::kFp16Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -881,7 +878,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_fp8_enable") {
} else if (option.first == migraphx_provider_option::kFp8Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp8_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -891,7 +888,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
} else if (option.first == migraphx_provider_option::kInt8Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -901,16 +898,15 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
} else if (option.first == migraphx_provider_option::kInt8CalibTable) {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
params.migraphx_int8_calibration_table_name = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a "
"file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
} else if (option.first == migraphx_provider_option::kInt8UseNativeCalibTable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -920,45 +916,43 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
} else if (option.first == migraphx_provider_option::kSaveCompiledModel ) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
params.migraphx_save_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
params.migraphx_save_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
} else if (option.first == migraphx_provider_option::kSaveModelPath) {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
params.migraphx_save_model_path = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
} else if (option.first == migraphx_provider_option::kLoadCompiledModel) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
params.migraphx_load_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
params.migraphx_load_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
} else if (option.first == migraphx_provider_option::kLoadModelPath) {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
params.migraphx_load_model_path = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
} else if (option.first == migraphx_provider_option::kExhaustiveTune) {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
Expand Down

0 comments on commit 90f06ee

Please sign in to comment.