From 3ae29230edaddb7801e32580568941f49d7a2f29 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:55:46 -0500 Subject: [PATCH] Add migx ep fp8 int4 (#78) * Add fp8 and int4 types in supported list for Onnxruntime EP * Add support for int4 inputs Map things to int8 right now as we don't explicitly set an int4 input type and pack/unpack int4 operands * Add flag to allow for fp8 quantization through Onnxruntime API * Add fp8 quantization to the compile stage of the MIGraphX EP Mirror the same calibration code we use for int8 and just change which quantize we call through the MIGraphx API * cleanup logging * Cleanup and encapsulate quantization / compile functions - Add additional flags for fp8 thats shared for int8 - Add lockout warning message when int8/fp8 used at the same time * Run lintrunner pass * Fix session options inputs + add better logging. Previous runs using session options failed as we were missing pulling in inputs from the python interface. This plus additional logging allowed me to track what options were invoked via env and what were added during the start of an inference session * Fix naming for save/load path varibles to be consistent with enable. * Print only env variables that are set as warnings need this so the user knows there's any of the environment variables running in the background to ensure proper consistently between runs. --------- Co-authored-by: Ted Themistokleous --- .../core/session/onnxruntime_c_api.h | 1 + .../migraphx/migraphx_execution_provider.cc | 298 ++++++++++++------ .../migraphx/migraphx_execution_provider.h | 11 +- .../migraphx_execution_provider_info.cc | 4 + .../migraphx_execution_provider_info.h | 1 + .../migraphx/migraphx_provider_factory.cc | 2 + .../python/onnxruntime_pybind_state.cc | 11 + onnxruntime/test/util/default_providers.cc | 1 + 8 files changed, 223 insertions(+), 106 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 370d2e13e6ff7..6d7794cfb1df4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -615,6 +615,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2cb0231b80caf..a4e11f705e250 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -108,32 +108,97 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); + get_flags_from_session_info(info); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + get_flags_from_env(); +} + +MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +} + +void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); + // Quantization + fp16_enable_ = info.fp16_enable; + fp8_enable_ = info.fp8_enable; + int8_enable_ = info.int8_enable; + + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ xor fp8_enable_) { + int8_calibration_cache_name_ = info.int8_calibration_table_name; + int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; + } + + if (int8_enable_ or fp8_enable_) { + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); + } + } + + // Save/load migraphx compiled models + save_compiled_model_ = info.save_compiled_model; + save_compiled_path_ = info.save_model_file; + load_compiled_model_ = info.load_compiled_model; + load_compiled_path_ = info.load_model_file; + + exhaustive_tune_ = info.exhaustive_tune; + + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; + print_migraphx_ep_flags(); +} + +void MIGraphXExecutionProvider::get_flags_from_env() { + LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; // whether fp16 is enable const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + } + + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; } // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } - if (int8_enable_) { + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = @@ -141,19 +206,21 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv if (!int8_use_native_migraphx_calibration_table_env.empty()) { int8_use_native_migraphx_calibration_table_ = (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " + << int8_use_native_migraphx_calibration_table_; } } - if (int8_enable_) { + if (int8_enable_ or fp8_enable_) { int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } // Load INT8 calibration table std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); } } @@ -161,55 +228,56 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); if (!save_comp_model_env.empty()) { save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; } const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { save_compiled_path_ = save_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; } const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); if (!load_comp_model_env.empty()) { load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; } const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); if (load_compiled_model_ && !load_model_path_env.empty()) { load_compiled_path_ = load_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; } // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } - - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - - LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << info_.device_id - << ", migraphx_fp16_enable: " << fp16_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", dump_model_ops: " << dump_model_ops_ - << ", exhaustive_tune: " << exhaustive_tune_ - << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ - << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ - << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << ", migraphx_save_compiled_model: " << save_compiled_model_ - << ", migraphx_save_compiled_model_path: " << save_compiled_path_ - << ", migraphx_load_compiled_model: " << load_compiled_model_ - << ", migraphx_load_compiled_model_path: " << load_compiled_path_; } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +void MIGraphXExecutionProvider::print_migraphx_ep_flags() { + LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id + << "\n migraphx_fp16_enable: " << fp16_enable_ + << "\n migraphx_fp8_enable: " << fp8_enable_ + << "\n migraphx_int8_enable: " << int8_enable_ + << "\n dump_model_ops: " << dump_model_ops_ + << "\n exhaustive_tune: " << exhaustive_tune_ + << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ + << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ + << "\n migraphx_save_compiled_model: " << save_compiled_model_ + << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ + << "\n migraphx_load_compiled_model: " << load_compiled_model_ + << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; } AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, @@ -274,11 +342,17 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -303,6 +377,21 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + mgx_type = migraphx_shape_fp8e4m3fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + mgx_type = migraphx_shape_fp8e4m3fn_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + mgx_type = migraphx_shape_fp8e5m2_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + mgx_type = migraphx_shape_fp8e5m2fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -315,6 +404,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -1063,7 +1155,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v if (dump_model_ops_) { LOGS_DEFAULT(INFO) << "============= Unsupported nodes ===================="; for (auto idx : unsupported_nodes) { - LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; + LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType(); } LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************"; } @@ -1136,9 +1228,9 @@ bool get_input_output_names(const GraphViewer& graph, bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { try { if (load_enable) { - LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path; + LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(INFO) << "load model : Success"; + LOGS_DEFAULT(WARNING) << "load model : Success"; return true; } else { return false; @@ -1151,14 +1243,73 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + } +} + +// Order matters here especially if the program uses mixed quantization +// Calibrate on full precision for int8/fp8 and then quantize down to fp16 +void calibrate_and_quantize(migraphx::program& prog, + const migraphx::target& t, + const migraphx::program_parameters quant_params, + bool fp16_enable, + bool int8_enable, + bool fp8_enable, + bool int8_calibration_cache_available, + std::unordered_map& dynamic_range_map) { + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + LOGS_DEFAULT(WARNING) << "Quantizing input program"; + + auto param_shapes = prog.get_parameter_shapes(); + + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : dynamic_range_map) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + } + + // perform static quantization on the programs + if (int8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to int8"; + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing int8: Complete"; + } else if (fp8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8"; + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing fp8: Complete"; + } + } + + if (fp16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp16"; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } } +void compile_program(migraphx::program& prog, + const migraphx::target& t, + bool exhaustive_tune) { + LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); + prog.compile(t, co); + LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1199,44 +1350,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - - auto param_shapes = prog.get_parameter_shapes(); - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map_) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; - } - - if (fp16_enable_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl; - } - - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - prog.compile(t_, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; - + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1258,7 +1376,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1283,6 +1401,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1291,7 +1410,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl; + LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1303,7 +1422,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl; + LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1345,15 +1464,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - + if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { auto param_shapes = prog.get_parameter_shapes(); - // Add input parameter data and the values they're set to for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { @@ -1372,34 +1486,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : map_dynamic_range) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; } - - if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl; - } - - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - prog.compile(t, co); - + calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); + compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } @@ -1414,7 +1504,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl; + LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1428,7 +1518,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl; + LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index eb39635dae5b5..6f3458a26c615 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,15 +17,16 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; +static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars @@ -43,6 +44,7 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); + void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); + void get_flags_from_env(); + void print_migraphx_ep_flags(); + Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; @@ -92,6 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 11270f2e64b82..cf21d791cfe6b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -21,6 +21,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -82,6 +83,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -100,6 +102,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -118,6 +121,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 3dbde11ddc4a9..a598052c5f025 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -42,6 +42,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 545b02a345830..519b8c7870092 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -82,6 +82,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -109,6 +110,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9e358828f7af8..4b7c0eccc4354 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -855,6 +855,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -880,6 +881,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 5ddf0eaaabb7b..268d10a1c4b5d 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -77,6 +77,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr",