Skip to content

Commit

Permalink
provider_options
Browse files Browse the repository at this point in the history
  • Loading branch information
apwojcik committed Feb 3, 2025
1 parent b6ab00f commit 090f327
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,6 @@ const EnumNameMapping<ArenaExtendStrategy> arena_extend_strategy_mapping{
{ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"},
};

namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kFp8Enable = "migx_fp8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";
constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";
constexpr const char* kMemLimit = "migx_mem_limit";
constexpr const char* kArenaExtendStrategy = "migx_arena_extend_strategy";
constexpr const char* kGpuExternalAlloc = "migx_external_alloc";
constexpr const char* kGpuExternalFree = "migx_external_free";
constexpr const char* kGpuExternalEmptyCache = "migx_external_empty_cache";

} // namespace provider_option_names
} // namespace migraphx

MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) {
MIGraphXExecutionProviderInfo info{};
void* alloc = nullptr;
Expand All @@ -47,7 +25,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
ORT_THROW_IF_ERROR(
ProviderOptionsParser{}
.AddValueParser(
migraphx::provider_option_names::kDeviceId,
migraphx_provider_option::kDeviceId,
[&info](const std::string& value_str) -> Status {
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id));
int num_devices{};
Expand All @@ -59,37 +37,37 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalAlloc,
migraphx_provider_option::kGpuExternalAlloc,
[&alloc](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
alloc = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalFree,
migraphx_provider_option::kGpuExternalFree,
[&free](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
free = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddValueParser(
migraphx::provider_option_names::kGpuExternalEmptyCache,
migraphx_provider_option::kGpuExternalEmptyCache,
[&empty_cache](const std::string& value_str) -> Status {
size_t address;
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address));
empty_cache = reinterpret_cast<void*>(address);
return Status::OK();
})
.AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx::provider_option_names::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.AddAssignmentToReference(migraphx_provider_option::kFp16Enable, info.fp16_enable)
.AddAssignmentToReference(migraphx_provider_option::kFp8Enable, info.fp8_enable)
.AddAssignmentToReference(migraphx_provider_option::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx_provider_option::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx_provider_option::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, info.exhaustive_tune)
.AddAssignmentToReference(migraphx_provider_option::kMemLimit, info.mem_limit)
.AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy)
.Parse(options));

MIGraphXExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache};
Expand All @@ -100,34 +78,33 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions

ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) {
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx::provider_option_names::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
{migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)},
{migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx_provider_option::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx_provider_option::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.mem_limit)},
{migraphx_provider_option::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.alloc))},
{migraphx_provider_option::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.free))},
{migraphx_provider_option::kGpuExternalEmptyCache, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.external_allocator_info.empty_cache))},
{migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
{migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
}

ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) {
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
{migraphx_provider_option::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx_provider_option::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx_provider_option::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)},
{migraphx_provider_option::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx_provider_option::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx_provider_option::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx_provider_option::kMemLimit, MakeStringWithClassicLocale(info.migraphx_mem_limit)},
{migraphx_provider_option::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast<onnxruntime::ArenaExtendStrategy>(info.migraphx_arena_extend_strategy))},
{migraphx_provider_option::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,25 @@

namespace onnxruntime {

namespace migraphx_provider_option {
constexpr auto kDeviceId = "device_id";
constexpr auto kFp16Enable = "migraphx_fp16_enable";
constexpr auto kFp8Enable = "migraphx_fp8_enable";
constexpr auto kInt8Enable = "migraphx_int8_enable";
constexpr auto kInt8CalibTable = "migraphx_int8_calibration_table_name";
constexpr auto kInt8UseNativeCalibTable = "migraphx_int8_use_native_calibration_table";
constexpr auto kSaveCompiledModel = "migraphx_save_compiled_model";
constexpr auto kSaveModelPath = "migraphx_save_model_name";
constexpr auto kLoadCompiledModel = "migraphx_load_compiled_model";
constexpr auto kLoadModelPath = "migraphx_load_model_name";
constexpr auto kExhaustiveTune = "migraphx_exhaustive_tune";
constexpr auto kMemLimit = "migraphx_mem_limit";
constexpr auto kArenaExtendStrategy = "migraphx_arena_extend_strategy";
constexpr auto kGpuExternalAlloc = "migraphx_external_alloc";
constexpr auto kGpuExternalFree = "migraphx_external_free";
constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache";
} // namespace migraphx_provider_option

// Information needed to construct MIGraphX execution providers.
struct MIGraphXExecutionProviderExternalAllocatorInfo {
void* alloc{nullptr};
Expand Down

0 comments on commit 090f327

Please sign in to comment.