Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move provider opion names to header file #86

Open
wants to merge 2 commits into
base: rocm6.4_internal_testing
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,6 @@ const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};

namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFp8Enable = "migx_fp8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
constexpr const char* kMemLimit = "migx_mem_limit";
constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy";
constexpr const char* kGpuExternalAlloc = "migx_external_alloc";
constexpr const char* kGpuExternalFree = "migx_external_free";
constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache";

} // namespace provider_option_names
} // namespace migraphx

MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
MIGraphXExecutionProviderInfo info{};
void* alloc = nullptr;
Expand All @@ -47,7 +25,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
migraphx::provider_option_names::kDeviceId,
migraphx_provider_option::kDeviceId,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
int num_devices{};
Expand All @@ -59,37 +37,37 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalAlloc,
migraphx_provider_option::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalFree,
migraphx_provider_option::kGpuExternalFree,
[&free](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalEmptyCache,
migraphx_provider_option::kGpuExternalEmptyCache,
[&empty_cache](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
empty_cache = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.AddAssignmentToReference(migraphx_provider_option::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx_provider_option::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx_provider_option::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.Parse(options));

MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
Expand All @@ -100,34 +78,33 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions

ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) {
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
{migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx_provider_option::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx_provider_option::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx_provider_option::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx_provider_option::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx_provider_option::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
}

ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) {
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
{migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx_provider_option::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx_provider_option::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@

namespace onnxruntime {

namespace migraphx_provider_option {
constexpr auto kDeviceId = "device_id";
constexpr auto kFp16Enable = "migraphx_fp16_enable";
constexpr auto kFp8Enable = "migraphx_fp8_enable";
constexpr auto kInt8Enable = "migraphx_int8_enable";
constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name";
constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table";
constexpr auto kSaveCompiledModel = "migraphx_save_compiled_model";
constexpr auto kSaveModelPath = "migraphx_save_model_name";
constexpr auto kLoadCompiledModel = "migraphx_load_compiled_model";
constexpr auto kLoadModelPath = "migraphx_load_model_name";
constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune";
constexpr auto kMemLimit = "migraphx_mem_limit";
constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy";
constexpr auto kGpuExternalAlloc = "migraphx_external_alloc";
constexpr auto kGpuExternalFree = "migraphx_external_free";
constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache";
} // namespace migraphx_provider_option

// Information needed to construct MIGraphX execution providers.
struct MIGraphXExecutionProviderExternalAllocatorInfo {
void* alloc{nullptr};
Expand Down
40 changes: 17 additions & 23 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -845,9 +845,6 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
std::string save_model_path;
std::string load_model_path;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
Expand All @@ -871,7 +868,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n");
}
} else if (option.first == "migraphx_fp16_enable") {
} else if (option.first == migraphx_provider_option::kFp16Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -881,7 +878,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_fp8_enable") {
} else if (option.first == migraphx_provider_option::kFp8Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp8_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -891,7 +888,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_fp8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
} else if (option.first == migraphx_provider_option::kInt8Enable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -901,16 +898,15 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
} else if (option.first == migraphx_provider_option::kInt8CalibTable) {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
params.migraphx_int8_calibration_table_name = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a "
"file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
} else if (option.first == migraphx_provider_option::kInt8UseNativeCalibTable) {
if (option.second == "True" || option.second == "true") {
params.migraphx_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
Expand All @@ -920,45 +916,43 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_compiled_model") {
} else if (option.first == migraphx_provider_option::kSaveCompiledModel ) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
params.migraphx_save_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
params.migraphx_save_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_save_model_path") {
} else if (option.first == migraphx_provider_option::kSaveModelPath) {
if (!option.second.empty()) {
save_model_path = option.second;
params.migraphx_save_model_path = save_model_path.c_str();
params.migraphx_save_model_path = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_save_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_load_compiled_model") {
} else if (option.first == migraphx_provider_option::kLoadCompiledModel) {
if (option.second == "True" || option.second == "true") {
params.migraphx_fp16_enable = true;
params.migraphx_load_compiled_model = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
params.migraphx_load_compiled_model = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_compiled_model' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_load_model_path") {
} else if (option.first == migraphx_provider_option::kLoadModelPath) {
if (!option.second.empty()) {
load_model_path = option.second;
params.migraphx_load_model_path = load_model_path.c_str();
params.migraphx_load_model_path = option.second.c_str();
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
} else if (option.first == migraphx_provider_option::kExhaustiveTune) {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
Expand Down
Loading