diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 370d2e13e6ff7..6d7794cfb1df4 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -615,6 +615,7 @@ typedef struct OrtTensorRTProviderOptions { typedef struct OrtMIGraphXProviderOptions { int device_id; // hip device id. int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_fp8_enable; // MIGraphX FP8 precision. Default 0 = false, nonzero = true int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 2cb0231b80caf..a4e11f705e250 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -108,32 +108,97 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); + get_flags_from_session_info(info); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + get_flags_from_env(); +} + +MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +} + +void MIGraphXExecutionProvider::get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info) { // Set GPU device to be used HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); + // Quantization + fp16_enable_ = info.fp16_enable; + fp8_enable_ = info.fp8_enable; + int8_enable_ = info.int8_enable; + + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "MIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ xor fp8_enable_) { + int8_calibration_cache_name_ = info.int8_calibration_table_name; + int8_use_native_migraphx_calibration_table_ = info.int8_use_native_calibration_table; + } + + if (int8_enable_ or fp8_enable_) { + int8_calibration_cache_available_ = !info.int8_calibration_table_name.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Session Failed to read INT8 calibration table " + calibration_cache_path); + } + } + + // Save/load migraphx compiled models + save_compiled_model_ = info.save_compiled_model; + save_compiled_path_ = info.save_model_file; + load_compiled_model_ = info.load_compiled_model; + load_compiled_path_ = info.load_model_file; + + exhaustive_tune_ = info.exhaustive_tune; + + LOGS_DEFAULT(WARNING) << "[MIGraphX EP] MIGraphX provider Session Options:"; + print_migraphx_ep_flags(); +} + +void MIGraphXExecutionProvider::get_flags_from_env() { + LOGS_DEFAULT(WARNING) << "\n[MIGraphX EP] MIGraphX ENV Override Variables Set:"; // whether fp16 is enable const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP16_ENABLE: " << fp16_enable_; + } + + // whether fp8 quantization is enabled + const std::string fp8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP8Enable); + if (!fp8_enable_env.empty()) { + fp8_enable_ = (std::stoi(fp8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_FP8_ENABLE: " << fp8_enable_; } // whether int8 is enabled const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); if (!int8_enable_env.empty()) { int8_enable_ = (std::stoi(int8_enable_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_ENABLE: " << int8_enable_; } - if (int8_enable_) { + if (int8_enable_ and fp8_enable_) { + LOGS_DEFAULT(FATAL) << "\nMIGraphX: FP8 and INT8 Quantization Mutually exclusive. Ignoring both Quantization flags"; + } + + if (int8_enable_ || fp8_enable_) { const std::string int8_calibration_cache_name_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); if (!int8_calibration_cache_name_env.empty()) { int8_calibration_cache_name_ = int8_calibration_cache_name_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CALIBRATION_TABLE_NAME: " << int8_calibration_cache_name_; } const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); if (!cache_path.empty()) { calibration_cache_path_ = cache_path; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_CACHE_PATH: " << calibration_cache_path_; } const std::string int8_use_native_migraphx_calibration_table_env = @@ -141,19 +206,21 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv if (!int8_use_native_migraphx_calibration_table_env.empty()) { int8_use_native_migraphx_calibration_table_ = (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE: " + << int8_use_native_migraphx_calibration_table_; } } - if (int8_enable_) { + if (int8_enable_ or fp8_enable_) { int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } // Load INT8 calibration table std::unordered_map dynamic_range_map; - if (int8_enable_ && int8_calibration_cache_available_) { + if ((int8_enable_ || fp8_enable_) && int8_calibration_cache_available_) { const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { - throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + throw std::runtime_error("ENV Failed to read calibration table " + calibration_cache_path); } } @@ -161,55 +228,56 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv const std::string save_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSaveCompiledModel); if (!save_comp_model_env.empty()) { save_compiled_model_ = (std::stoi(save_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_MODEL: " << save_compiled_model_; } const std::string save_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kSavedModelPath); - if (save_compiled_model_ && !save_model_path_env.empty()) { save_compiled_path_ = save_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_SAVE_COMPILED_PATH: " << save_compiled_path_; } const std::string load_comp_model_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadCompiledModel); if (!load_comp_model_env.empty()) { load_compiled_model_ = (std::stoi(load_comp_model_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_MODEL: " << load_compiled_model_; } const std::string load_model_path_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kLoadModelPath); if (load_compiled_model_ && !load_model_path_env.empty()) { load_compiled_path_ = load_model_path_env; + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_LOAD_COMPILED_PATH: " << load_compiled_path_; } // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_DUMP_MODEL_OPS: " << dump_model_ops_; } // Allow for exhaustive tune during compile const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); if (!exhaustive_tune_env.empty()) { exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + LOGS_DEFAULT(WARNING) << "\nORT_MIGRAPHX_EXHAUSTIVE_TUNE_OPS: " << exhaustive_tune_; } - - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); - - LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << info_.device_id - << ", migraphx_fp16_enable: " << fp16_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", migraphx_int8_enable: " << int8_enable_ - << ", dump_model_ops: " << dump_model_ops_ - << ", exhaustive_tune: " << exhaustive_tune_ - << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ - << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ - << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ - << ", migraphx_save_compiled_model: " << save_compiled_model_ - << ", migraphx_save_compiled_model_path: " << save_compiled_path_ - << ", migraphx_load_compiled_model: " << load_compiled_model_ - << ", migraphx_load_compiled_model_path: " << load_compiled_path_; } -MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { +void MIGraphXExecutionProvider::print_migraphx_ep_flags() { + LOGS_DEFAULT(WARNING) << "\n device_id: " << info_.device_id + << "\n migraphx_fp16_enable: " << fp16_enable_ + << "\n migraphx_fp8_enable: " << fp8_enable_ + << "\n migraphx_int8_enable: " << int8_enable_ + << "\n dump_model_ops: " << dump_model_ops_ + << "\n exhaustive_tune: " << exhaustive_tune_ + << "\n migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << "\n int8_calibration_cache_available: " << int8_calibration_cache_available_ + << "\n use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ + << "\n migraphx_save_compiled_model: " << save_compiled_model_ + << "\n migraphx_save_compiled_model_path: " << save_compiled_path_ + << "\n migraphx_load_compiled_model: " << load_compiled_model_ + << "\n migraphx_load_compiled_model_path: " << load_compiled_path_; } AllocatorPtr MIGraphXExecutionProvider::CreateMIGraphXAllocator(OrtDevice::DeviceId device_id, @@ -274,11 +342,17 @@ static bool IsTypeSupported(const NodeArg* node_arg) { switch (type_proto->tensor_type().elem_type()) { case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT4: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: @@ -303,6 +377,21 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ: + mgx_type = migraphx_shape_fp8e4m3fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN: + mgx_type = migraphx_shape_fp8e4m3fn_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2: + mgx_type = migraphx_shape_fp8e5m2_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ: + mgx_type = migraphx_shape_fp8e5m2fnuz_type; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4: + mgx_type = migraphx_shape_int8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; @@ -315,6 +404,9 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4: + mgx_type = migraphx_shape_uint8_type; + break; case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; @@ -1063,7 +1155,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v if (dump_model_ops_) { LOGS_DEFAULT(INFO) << "============= Unsupported nodes ===================="; for (auto idx : unsupported_nodes) { - LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType() << std::endl; + LOGS_DEFAULT(INFO) << graph_viewer.GetNode(idx)->OpType(); } LOGS_DEFAULT(INFO) << "************* Unsupported nodes ********************"; } @@ -1136,9 +1228,9 @@ bool get_input_output_names(const GraphViewer& graph, bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::string path) { try { if (load_enable) { - LOGS_DEFAULT(INFO) << "Attempting to load model at:" << path; + LOGS_DEFAULT(WARNING) << "Attempting to load model at:" << path; prog = migraphx::load(path.c_str()); - LOGS_DEFAULT(INFO) << "load model : Success"; + LOGS_DEFAULT(WARNING) << "load model : Success"; return true; } else { return false; @@ -1151,14 +1243,73 @@ bool load_precompiled_model(migraphx::program& prog, bool load_enable, std::stri void save_compiled_model(migraphx::program& prog, bool save_enable, std::string out_path) { if (save_enable) { - LOGS_DEFAULT(INFO) << "Model Save at " << out_path << ": Begin" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save at " << out_path << ": Begin"; migraphx::file_options fo; fo.set_file_format("msgpack"); migraphx::save(prog, out_path.c_str(), fo); - LOGS_DEFAULT(INFO) << "Model Save: Complete" << std::endl; + LOGS_DEFAULT(WARNING) << "Model Save: Complete"; + } +} + +// Order matters here especially if the program uses mixed quantization +// Calibrate on full precision for int8/fp8 and then quantize down to fp16 +void calibrate_and_quantize(migraphx::program& prog, + const migraphx::target& t, + const migraphx::program_parameters quant_params, + bool fp16_enable, + bool int8_enable, + bool fp8_enable, + bool int8_calibration_cache_available, + std::unordered_map& dynamic_range_map) { + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if ((int8_enable xor fp8_enable) && int8_calibration_cache_available) { + LOGS_DEFAULT(WARNING) << "Quantizing input program"; + + auto param_shapes = prog.get_parameter_shapes(); + + // Add all calibration data read in from int8 table + for (auto& [cal_key, cal_val] : dynamic_range_map) { + auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); + quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); + } + + // perform static quantization on the programs + if (int8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to int8"; + migraphx::quantize_int8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + // specify thing we want to int8 quantize + quant_opts.add_op_name("convolution"); + quant_opts.add_op_name("dot"); + migraphx::quantize_int8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing int8: Complete"; + } else if (fp8_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp8"; + migraphx::quantize_fp8_options quant_opts; + quant_opts.add_calibration_data(quant_params); + migraphx::quantize_fp8(prog, t, quant_opts); + LOGS_DEFAULT(WARNING) << "Quantizing fp8: Complete"; + } + } + + if (fp16_enable) { + LOGS_DEFAULT(WARNING) << "Quantizing input program to fp16"; + migraphx::quantize_fp16(prog); + LOGS_DEFAULT(WARNING) << "Quantizing fp16: Complete"; } } +void compile_program(migraphx::program& prog, + const migraphx::target& t, + bool exhaustive_tune) { + LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); + prog.compile(t, co); + LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; +} + Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { migraphx::onnx_options options; @@ -1199,44 +1350,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) { LOGS_DEFAULT(INFO) << "No input shapes detected quantizing model"; prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable_ && int8_calibration_cache_available_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to int8" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - - auto param_shapes = prog.get_parameter_shapes(); - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : dynamic_range_map_) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t_, quant_opts); - LOGS_DEFAULT(INFO) << "Quantizing input program to int8: Complete" << std::endl; - } - - if (fp16_enable_) { - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantizing input program to fp16: Complete" << std::endl; - } - - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - prog.compile(t_, co); - LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; - + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); + compile_program(prog, t_, exhaustive_tune_); save_compiled_model(prog, save_compiled_model_, save_compiled_path_); } @@ -1258,7 +1376,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + map_no_input_shape_[context->node_name], fp16_enable_, fp8_enable_, int8_enable_, int8_calibration_cache_available_, dynamic_range_map_, save_compiled_model_, save_compiled_path_, load_compiled_model_, load_compiled_path_, dump_model_ops_}; @@ -1283,6 +1401,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool fp8_enable = mgx_state->fp8_enable; bool int8_enable = mgx_state->int8_enable; bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; @@ -1291,7 +1410,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& bool input_shape_match = true; migraphx::program_parameter_shapes param_shapes; if (no_input_shape) { - LOGS_DEFAULT(VERBOSE) << "Missing input shape setting input parameters again" << std::endl; + LOGS_DEFAULT(INFO) << "Missing input shape setting input parameters again"; for (auto& it : map_input_name_index) { auto& name = it.first; auto& index = it.second; @@ -1303,7 +1422,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_shape_match = false; } } else { - LOGS_DEFAULT(VERBOSE) << "Assigning inputs, and parameters from compiled model" << std::endl; + LOGS_DEFAULT(INFO) << "Assigning inputs, and parameters from compiled model"; param_shapes = prog.get_parameter_shapes(); auto prog_output_shapes = prog.get_output_shapes(); @@ -1345,15 +1464,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& #endif #endif prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + migraphx::program_parameters quant_params; - // Read in the calibration data and map it to an migraphx paramater map for the calibration ops - if (int8_enable && int8_calibration_cache_available) { - LOGS_DEFAULT(INFO) << "Quantize Int8: Begin" << std::endl; - migraphx::quantize_int8_options quant_opts; - migraphx::program_parameters quant_params; - + if ((int8_enable xor fp8_enable) and int8_calibration_cache_available) { auto param_shapes = prog.get_parameter_shapes(); - // Add input parameter data and the values they're set to for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { @@ -1372,34 +1486,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& quant_params.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } } - - // Add all calibration data read in from int8 table - for (auto& [cal_key, cal_val] : map_dynamic_range) { - auto cal_val_shape = migraphx::shape(migraphx_shape_float_type); - quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast(std::move(&cal_val)))); - } - quant_opts.add_calibration_data(quant_params); - // specify thing we want to int8 quantize - quant_opts.add_op_name("convolution"); - quant_opts.add_op_name("dot"); - - // perform static quantization on the programs - migraphx::quantize_int8(prog, t, quant_opts); - LOGS_DEFAULT(INFO) << "Quantize Int8: Completed" << std::endl; } - - if (fp16_enable) { - LOGS_DEFAULT(INFO) << "Quantize fp16: Begin" << std::endl; - migraphx::quantize_fp16(prog); - LOGS_DEFAULT(INFO) << "Quantize fp16: Completed" << std::endl; - } - - LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; - migraphx::compile_options co; - co.set_fast_math(false); - co.set_exhaustive_tune_flag(exhaustive_tune_); - prog.compile(t, co); - + calibrate_and_quantize(prog, t, quant_params, fp16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); + compile_program(prog, t, exhaustive_tune_); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); } @@ -1414,7 +1504,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (param_shapes.size() > 0) { for (auto&& name : param_shapes.names()) { if (map_input_name_index.count(name) > 0) { - LOGS_DEFAULT(INFO) << "Setting parameters for:" << name << std::endl; + LOGS_DEFAULT(INFO) << "Setting parameters for:" << name; auto input_tensor = ctx.GetInput(map_input_name_index[name]); auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shape = tensor_info.GetShape(); @@ -1428,7 +1518,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - LOGS_DEFAULT(INFO) << "Writing Raw tensor data " << std::endl; + LOGS_DEFAULT(INFO) << "Writing Raw tensor data "; m.add(name, migraphx::argument(param_shapes[name], const_cast(input_tensor.GetTensorRawData()))); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index eb39635dae5b5..6f3458a26c615 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -17,15 +17,16 @@ namespace onnxruntime { namespace migraphx_env_vars { static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kFP8Enable[] = "ORT_MIGRAPHX_FP8_ENABLE"; static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; -static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; +static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILED_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; -static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILED_PATH"; static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars @@ -43,6 +44,7 @@ struct MIGraphXFuncState { std::mutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool fp8_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; std::unordered_map dynamic_range_map; @@ -60,6 +62,10 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); + void get_flags_from_session_info(const MIGraphXExecutionProviderInfo& info); + void get_flags_from_env(); + void print_migraphx_ep_flags(); + Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; @@ -92,6 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; + bool fp8_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; bool int8_calibration_cache_available_ = false; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 11270f2e64b82..cf21d791cfe6b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -21,6 +21,7 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kFp8Enable = "migx_fp8_enable"; constexpr const char* kInt8Enable = "migx_int8_enable"; constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; @@ -82,6 +83,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kFp8Enable, info.fp8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) @@ -100,6 +102,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, @@ -118,6 +121,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kFp8Enable, MakeStringWithClassicLocale(info.migraphx_fp8_enable)}, {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 3dbde11ddc4a9..a598052c5f025 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -42,6 +42,7 @@ struct MIGraphXExecutionProviderInfo { std::string target_device; OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; + bool fp8_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 545b02a345830..519b8c7870092 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -82,6 +82,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.fp8_enable = options.migraphx_fp8_enable; info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; @@ -109,6 +110,7 @@ struct MIGraphX_Provider : Provider { auto& migx_options = *reinterpret_cast(provider_options); migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_fp8_enable = internal_options.fp8_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9e358828f7af8..4b7c0eccc4354 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -855,6 +855,7 @@ std::unique_ptr CreateExecutionProviderInstance( 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr", @@ -880,6 +881,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be" " 'True' or 'False'. Default value is 'False'.\n"); } + } else if (option.first == "migraphx_fp8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migx_fp8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else if (option.first == "migraphx_int8_enable") { if (option.second == "True" || option.second == "true") { params.migraphx_int8_enable = true; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 5ddf0eaaabb7b..268d10a1c4b5d 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -77,6 +77,7 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 0, 0, 0, + 0, nullptr, 1, "./compiled_model.mxr",