From 4ef126ba4e8b86c05857cc3a2f8fda0207458900 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 02:58:50 +0000
Subject: [PATCH 1/9] Update LLVM GPU backend for ROCm support, make the ROCm
 LLVM compiler the default option

---
 .../xla/service/gpu/llvm_gpu_backend/BUILD    |   1 +
 .../gpu/llvm_gpu_backend/gpu_backend_lib.cc   | 424 ++++++++++++++----
 2 files changed, 340 insertions(+), 85 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
index 56e4d551bd456a..6ea5b49ca13381 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -65,6 +65,7 @@ cc_library(
         "@llvm-project//llvm:Target",
     ] + if_rocm_is_configured([
         "@local_config_rocm//rocm:rocm_headers",
+        "//tensorflow/tsl/platform:rocm_rocdl_path",
        "@llvm-project//llvm:AMDGPUCodeGen",
     ]),
 )
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 8699db2c643ce7..b93482015882e4 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 
 #include
+#include
 #include
 #include
 #include
@@ -60,7 +61,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/tsl/platform/cuda_libdevice_path.h"
+
 #include "tensorflow/tsl/platform/env.h"
 #include "tensorflow/tsl/platform/logging.h"
 #include "tensorflow/tsl/platform/path.h"
@@ -68,10 +69,16 @@ limitations under the License.
 #include "tensorflow/tsl/profiler/lib/traceme.h"
 #include "tensorflow/tsl/util/env_var.h"
 
+#if !defined(PLATFORM_GOOGLE) && TENSORFLOW_USE_ROCM
+#include "tensorflow/tsl/platform/rocm_rocdl_path.h"
 #include "rocm/rocm_config.h"
+#else
+#include "tensorflow/tsl/platform/cuda_libdevice_path.h"
 #endif
 
+#define TENSORFLOW_HSACO_USE_ROCM_LLVM
+
 namespace xla {
 namespace gpu {
 namespace {
@@ -227,6 +234,82 @@ bool CouldNeedDeviceBitcode(const llvm::Module& module) {
 
 // Links the module with a vector of path to bitcode modules.
 // The caller must guarantee that the paths exist.
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+Status LinkWithBitcodeVector(
+    llvm::Module* module, const std::vector<std::string>& bitcode_path_vector,
+    const std::string& ir_path, const std::string& linked_ir_path,
+    const std::string& optimized_ir_path) {
+  std::error_code ec;
+  std::string error_message;
+
+  for (auto& bitcode_path : bitcode_path_vector) {
+    if (!tsl::Env::Default()->FileExists(bitcode_path).ok()) {
+      LOG(ERROR) << "bitcode module is required by this HLO module but was "
+                    "not found at "
+                 << bitcode_path;
+      return xla::InternalError("bitcode module not found at %s",
+                                bitcode_path);
+    }
+  }
+
+  // Dump LLVM IR.
+  std::unique_ptr<llvm::raw_fd_ostream> ir_fs(
+      new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::OF_None));
+  module->print(*ir_fs, nullptr);
+  ir_fs->flush();
+
+  // Locate llvm-link.
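+  // Unlike the in-process llvm::Linker fallback in the #else branch below,
+  // this path shells out to the llvm-link and opt binaries shipped with ROCm.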
+  std::string llvmlink_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  auto llvmlink_program =
+      llvm::sys::findProgramByName("llvm-link", {llvmlink_path});
+  if (!llvmlink_program) {
+    return xla::InternalError("unable to find llvm-link in PATH: %s",
+                              llvmlink_program.getError().message());
+  }
+  // Setup llvm-link arguments.
+  std::vector<llvm::StringRef> llvmlink_args{
+      llvm_ir::AsStringRef("llvm-link"),
+      llvm_ir::AsStringRef("-o"),
+      llvm_ir::AsStringRef(linked_ir_path),
+  };
+
+  llvmlink_args.push_back(llvm_ir::AsStringRef(ir_path));
+  for (auto& bitcode_path : bitcode_path_vector) {
+    llvmlink_args.push_back(llvm_ir::AsStringRef(bitcode_path));
+  }
+
+  int llvmlink_result = llvm::sys::ExecuteAndWait(
+      *llvmlink_program, llvm_ir::AsArrayRef(llvmlink_args), std::nullopt, {},
+      0, 0, &error_message);
+
+  if (llvmlink_result) {
+    return xla::InternalError("llvm-link execution failed: %s", error_message);
+  }
+
+  // Locate opt.
+  std::string opt_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  auto opt_program = llvm::sys::findProgramByName("opt", {opt_path});
+  if (!opt_program) {
+    return xla::InternalError("unable to find opt in PATH: %s",
+                              opt_program.getError().message());
+  }
+  std::vector<llvm::StringRef> opt_args{
+      llvm_ir::AsStringRef("opt"),
+      llvm_ir::AsStringRef("-O3"),
+      llvm_ir::AsStringRef("-o"),
+      llvm_ir::AsStringRef(optimized_ir_path),
+      llvm_ir::AsStringRef(linked_ir_path),
+  };
+
+  int opt_result =
+      llvm::sys::ExecuteAndWait(*opt_program, llvm_ir::AsArrayRef(opt_args),
+                                std::nullopt, {}, 0, 0, &error_message);
+
+  if (opt_result) {
+    return xla::InternalError("opt execution failed: %s", error_message);
+  }
+  return OkStatus();
+}
+#else
 Status LinkWithBitcodeVector(
     llvm::Module* module, const std::vector<std::string>& bitcode_path_vector) {
   llvm::Linker linker(*module);
@@ -258,6 +341,9 @@ Status LinkWithBitcodeVector(
 
   return OkStatus();
 }
+#endif
+
+#ifdef GOOGLE_CUDA
 // Links libdevice into the given module if the module needs libdevice.
 Status LinkLibdeviceIfNecessary(llvm::Module* module,
                                 const std::string& libdevice_dir_path) {
@@ -282,7 +368,10 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
 
 Status NVPTXTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                                const HloModuleConfig& hlo_module_config,
-                               const std::string& device_bitcode_dir_path) {
+                               const std::string& device_bitcode_dir_path,
+                               const std::string& ir_path,
+                               const std::string& linked_ir_path,
+                               const std::string& optimized_ir_path) {
   // Link the input module with libdevice, to pull in implementations of some
   // builtins.
   TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_dir_path));
@@ -315,9 +404,15 @@ std::unique_ptr<llvm::TargetMachine> NVPTXGetTargetMachine(
   return GetTargetMachine(target_triple, GetSmName(compute_capability),
                           hlo_module_config, ptx_ver);
 }
-
+#endif
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+using TargetModuleLinker = std::function<Status(
+    llvm::Module*, GpuVersion, const HloModuleConfig&, const std::string&,
+    const std::string&, const std::string&, const std::string&)>;
+#else
 using TargetModuleLinker = std::function<Status(
     llvm::Module*, GpuVersion, const HloModuleConfig&, const std::string&)>;
+#endif
 
 void DumpModule(const std::string output_filename, const llvm::Module* module) {
   std::error_code ec;
@@ -372,7 +467,21 @@ auto DumpCallbackForModule(std::string module_identifier,
     DumpModule(tsl::io::JoinPath(outputs_dir, basename), module);
   };
 }
-
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
+                             const HloModuleConfig& hlo_module_config,
+                             const std::string& device_bitcode_dir_path,
+                             TargetModuleLinker module_linker,
+                             llvm::Triple default_target_triple,
+                             llvm::TargetMachine* target_machine,
+                             int inline_threshold, const std::string& ir_path,
+                             const std::string& linked_ir_path,
+                             const std::string& optimized_ir_path) {
+  return module_linker(module, gpu_version, hlo_module_config,
+                       device_bitcode_dir_path, ir_path, linked_ir_path,
+                       optimized_ir_path);
+}
+#else
 Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
                              const HloModuleConfig& hlo_module_config,
                              const std::string& device_bitcode_dir_path,
@@ -464,6 +573,7 @@ Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version,
 
   return OkStatus();
 }
+#endif
 // One-time module initializer.
 // Must be called only once -- DO NOT CALL DIRECTLY.
 void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) {
@@ -517,7 +627,7 @@ void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) {
 }  // namespace
 
 namespace nvptx {
-
+#ifdef GOOGLE_CUDA
 std::string CantFindCudaMessage(absl::string_view msg,
                                 absl::string_view xla_gpu_cuda_data_dir) {
   return absl::StrCat(
@@ -629,10 +739,34 @@ StatusOr<std::string> CompileToPtx(
   }
   return ptx;
 }
-
+#endif  // GOOGLE_CUDA
 }  // namespace nvptx
 
 namespace {
+static std::string hsaco_cache_dir_;
+static std::mutex hsaco_cache_mutex_;
+static absl::flat_hash_map<std::string, std::vector<uint8_t>> hsaco_cache_;
+
+static void InitHsacoCacheDir() {
+  static absl::once_flag init_once;
+  absl::call_once(init_once, [] {
+    auto env = tsl::Env::Default();
+    tsl::ReadStringFromEnvVar("TF_XLA_HSACO_CACHE_DIR", "/tmp",
+                              &hsaco_cache_dir_);
+    if (hsaco_cache_dir_.empty()) {
+      LOG(INFO) << "Will not cache XLA HSACOs. "
+                << "This line is logged at most "
+                << "once for the lifetime of the process.";
+    } else {
+      if (!env->IsDirectory(hsaco_cache_dir_).ok()) {
+        env->CreateDir(hsaco_cache_dir_);
+      }
+      LOG(INFO) << "Cache XLA HSACOs in " << hsaco_cache_dir_ << ". "
+                << "This line is logged at most "
+                << "once for the lifetime of the process.";
+    }
+  });
+}
 
 // Gets the ROCm-Device-Libs filenames for a particular AMDGPU version.
 std::vector<std::string> GetROCDLPaths(std::string gcn_arch_name,
@@ -642,7 +776,8 @@ std::vector<std::string> GetROCDLPaths(std::string gcn_arch_name,
       new std::vector<std::string>(
          {"opencl.bc", "ocml.bc", "ockl.bc", "oclc_finite_only_off.bc",
           "oclc_daz_opt_off.bc", "oclc_correctly_rounded_sqrt_on.bc",
-          "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc"});
+          "oclc_unsafe_math_off.bc", "oclc_wavefrontsize64_on.bc", "hip.bc",
+          "oclc_abi_version_500.bc"});
 
   // Construct full path to ROCDL bitcode libraries.
   std::vector<std::string> result;
@@ -663,63 +798,99 @@ std::vector<std::string> GetROCDLPaths(std::string gcn_arch_name,
   return result;
 }
 
-struct HsacoCacheEntry {
-  uint64_t hash;
-  std::string ir;
-  std::string gfx;
-  std::vector<uint8_t> hsaco;
-};
-
-struct HsacoCache {
- protected:
-  std::vector<HsacoCacheEntry> cache;
-  std::mutex m_mutex;
-  int request_count = 0;
-  int hit_count = 0;
-
- public:
-  static bool Find(const std::string& ir, uint64_t& hash,
-                   const std::string& gfx, std::vector<uint8_t>& hsaco);
-  static void Add(const std::string& ir, uint64_t hash, const std::string& gfx,
-                  const std::vector<uint8_t>& hsaco);
-};
-
-static HsacoCache g_hsacoCache;
-
-bool HsacoCache::Find(const std::string& ir, uint64_t& hash,
-                      const std::string& gfx, std::vector<uint8_t>& hsaco) {
-  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
-  hash = std::hash<std::string>{}(ir);
-  bool hit = false;
-  for (auto& x : g_hsacoCache.cache) {
-    if (x.hash != hash) continue;
-    if (x.gfx != gfx) continue;
-    if (x.ir != ir) continue;
-    hsaco = x.hsaco;
-    hit = true;
-    break;
+Status ReadHsaco(std::string hsaco_path, std::vector<uint8_t>& hsaco) {
+  std::lock_guard<std::mutex> lg(hsaco_cache_mutex_);
+  auto it = hsaco_cache_.find(hsaco_path);
+  if (it != hsaco_cache_.end()) {
+    VLOG(1) << "Hsaco cache hit in memory " << hsaco_path;
+    hsaco = it->second;
+    return OkStatus();
   }
-  g_hsacoCache.request_count++;
-  if (hit) g_hsacoCache.hit_count++;
-  if (!(g_hsacoCache.request_count % 50))
-    VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, "
-            << g_hsacoCache.hit_count << " hits";
-  return hit;
-}
-
-void HsacoCache::Add(const std::string& ir, uint64_t hash,
-                     const std::string& gfx,
-                     const std::vector<uint8_t>& hsaco) {
-  std::lock_guard<std::mutex> lg(g_hsacoCache.m_mutex);
-  g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1);
-  g_hsacoCache.cache.back().ir = ir;
-  g_hsacoCache.cache.back().hash = hash;
-  g_hsacoCache.cache.back().gfx = gfx;
-  g_hsacoCache.cache.back().hsaco = hsaco;
+  if (tsl::Env::Default()->FileExists(hsaco_path).ok()) {
+    VLOG(1) << "Hsaco cache hit in file " << hsaco_path;
+    std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate);
+    std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg();
+    hsaco = std::vector<uint8_t>(hsaco_file_size);
+    hsaco_file.seekg(0, std::ios::beg);
+    hsaco_file.read(reinterpret_cast<char*>(&hsaco[0]), hsaco_file_size);
+    hsaco_cache_.emplace(hsaco_path, hsaco);
+    return OkStatus();
+  }
+  return xla::InternalErrorStrCat("Can't find Hsaco: ", hsaco_path);
 }
 
 // Emits the given module to HSA Code Object.  target_machine is an initialized
 // TargetMachine for the AMDGPU target.
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+
+StatusOr<std::vector<uint8_t>> EmitModuleToHsaco(
+    llvm::Module* module, llvm::TargetMachine* target_machine,
+    const std::string& optimized_ir_path, const std::string& isabin_path,
+    const std::string& hsaco_path, std::string& gcn_arch_name) {
+  std::string error_message;
+  std::vector<std::string> tokens = absl::StrSplit(gcn_arch_name, ':');
+  std::string gfx = tokens[0];
+  // Locate llc.
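+  // The optimized IR written by LinkWithBitcodeVector is lowered to an
+  // AMDGPU object file by invoking the standalone llc tool.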
+  std::string llc_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  auto llc_program = llvm::sys::findProgramByName("llc", {llc_path});
+  if (!llc_program) {
+    return xla::InternalError("unable to find llc in PATH: %s",
+                              llc_program.getError().message());
+  }
+  std::vector<llvm::StringRef> llc_args{
+      llvm_ir::AsStringRef("llc"),
+      llvm_ir::AsStringRef("-march=amdgcn"),
+      llvm_ir::AsStringRef(absl::StrCat("-mcpu=", gfx)),
+      llvm_ir::AsStringRef("--amdgpu-kernarg-preload-count=16"),
+      llvm_ir::AsStringRef("-filetype=obj"),
+      llvm_ir::AsStringRef("-o"),
+      llvm_ir::AsStringRef(isabin_path),
+      llvm_ir::AsStringRef(optimized_ir_path),
+  };
+
+  int llc_result =
+      llvm::sys::ExecuteAndWait(*llc_program, llvm_ir::AsArrayRef(llc_args),
+                                std::nullopt, {}, 0, 0, &error_message);
+
+  if (llc_result) {
+    return xla::InternalError("llc execution failed: %s", error_message);
+  }
+
+  // Locate lld.
+  // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after
+  // ROCm-Device-Libs PR.
+  std::string lld_path_1 = tsl::io::JoinPath(tsl::RocmRoot(), "hcc/bin");
+  std::string lld_path_2 = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin");
+  auto lld_program =
+      llvm::sys::findProgramByName("ld.lld", {lld_path_1, lld_path_2});
+  if (!lld_program) {
+    return xla::InternalError("unable to find ld.lld in PATH: %s",
+                              lld_program.getError().message());
+  }
+  std::vector<llvm::StringRef> lld_args{
+      llvm_ir::AsStringRef("ld.lld"),    llvm_ir::AsStringRef("-flavor"),
+      llvm_ir::AsStringRef("gnu"),       llvm_ir::AsStringRef("-shared"),
+      llvm_ir::AsStringRef(isabin_path), llvm_ir::AsStringRef("-o"),
+      llvm_ir::AsStringRef(hsaco_path),
+  };
+
+  int lld_result =
+      llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args),
                                std::nullopt, {}, 0, 0, &error_message);
+
+  if (lld_result) {
+    return xla::InternalError("ld.lld execution failed: %s", error_message);
+  }
+
+  // Read HSACO.
+  std::vector<uint8_t> hsaco;
+  TF_RETURN_IF_ERROR(ReadHsaco(hsaco_path, hsaco));
+  return hsaco;
+}
+
+#else
+
 StatusOr<std::vector<uint8_t>> EmitModuleToHsaco(
     llvm::Module* module, llvm::TargetMachine* target_machine) {
   auto* env = tsl::Env::Default();
@@ -828,7 +999,19 @@ StatusOr<std::vector<uint8_t>> EmitModuleToHsaco(
 
   return hsaco;
 }
+#endif  // TENSORFLOW_HSACO_USE_ROCM_LLVM
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
 // Links ROCm-Device-Libs into the given module if the module needs it.
+Status LinkROCDLIfNecessary(llvm::Module* module, std::string gcn_arch_name,
+                            const std::string& rocdl_dir_path,
+                            const std::string& ir_path,
+                            const std::string& linked_ir_path,
+                            const std::string& optimized_ir_path) {
+  return LinkWithBitcodeVector(module,
+                               GetROCDLPaths(gcn_arch_name, rocdl_dir_path),
+                               ir_path, linked_ir_path, optimized_ir_path);
+}
+#else
 Status LinkROCDLIfNecessary(llvm::Module* module, std::string gcn_arch_name,
                             const std::string& rocdl_dir_path) {
   if (!CouldNeedDeviceBitcode(*module)) {
@@ -838,7 +1021,42 @@ Status LinkROCDLIfNecessary(llvm::Module* module, std::string gcn_arch_name,
   return LinkWithBitcodeVector(module,
                                GetROCDLPaths(gcn_arch_name, rocdl_dir_path));
 }
+#endif
 
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
+                                const HloModuleConfig& hlo_module_config,
+                                const std::string& device_bitcode_dir_path,
+                                const std::string& ir_path,
+                                const std::string& linked_ir_path,
+                                const std::string& optimized_ir_path) {
+  // Link the input module with ROCDL.
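+  // The linking itself happens out of process (llvm-link/opt); the IR is
+  // exchanged through the temporary files passed in by the caller.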
+
+  auto compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version);
+  if (!compute_capability) {
+    return xla::InternalError("Incompatible compute capability was specified.");
+  }
+
+  std::string gcn_arch_name = compute_capability->gcn_arch_name();
+  TF_RETURN_IF_ERROR(LinkROCDLIfNecessary(module, gcn_arch_name,
+                                          device_bitcode_dir_path, ir_path,
+                                          linked_ir_path, optimized_ir_path));
+
+  // For rocm, we always enable flush to zero. (For cuda, this is determined
+  // via environment variables.) This decision was based on Eugene's
+  // observation that the AMD GPU LLVM backend did not pick up the atomic add
+  // instructions correctly without ftz enabled. We concluded that this should
+  // not have a major impact, as the hipcc path enables flush to zero by
+  // default for compilation.
+  for (llvm::Function& fn : *module) {
+    // may be necessary for the compiler to generate atomics (confirm!)
+    fn.addFnAttr("denormal-fp-math-f32", "preserve-sign");
+    fn.addFnAttr("amdgpu-unsafe-fp-atomics", "true");
+  }
+  return OkStatus();
+}
+#else
 Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
                                 const HloModuleConfig& hlo_module_config,
                                 const std::string& device_bitcode_dir_path) {
@@ -868,7 +1086,7 @@ Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version,
 
   return OkStatus();
 }
-
+#endif
 // The following routine maps a feature token extracted from the
 // hipDeviceProp_t::gcnArchName string, and maps it to a valid feature_str
 // to be used for creating the AMDGPUTarget.
@@ -884,7 +1102,8 @@ std::string MapGCNArchNameTokenToFeatureStr(const std::string& token,
   if (token == "sramecc+") {
     return "+sramecc";
   } else if (token == "sramecc-") {
-    if(gfx == "gfx90a" || gfx == "gfx940" || gfx == "gfx941" || gfx == "gfx942")
+    if (gfx == "gfx90a" || gfx == "gfx940" || gfx == "gfx941" ||
+        gfx == "gfx942")
       return "";
     return "-sramecc";
   } else if (token == "xnack+") {
@@ -893,7 +1112,6 @@ std::string MapGCNArchNameTokenToFeatureStr(const std::string& token,
     return "-xnack";
   }
   return "";
-
 }
 
 std::pair<std::string, std::string> GetFeatureStrFromGCNArchName(
@@ -939,14 +1157,13 @@ void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
   // Initialize the AMDGPU target; it's the only target we link with, so call
   // its specific initialization functions instead of the catch-all
   // InitializeAll*.
-#if TENSORFLOW_USE_ROCM
+
+  InitHsacoCacheDir();
   LLVMInitializeAMDGPUTarget();
   LLVMInitializeAMDGPUTargetInfo();
   LLVMInitializeAMDGPUTargetMC();
   LLVMInitializeAMDGPUAsmPrinter();
 
-#endif
-
   llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
   InitializePasses(registry);
 }
@@ -954,18 +1171,21 @@ void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
 }  // namespace
 
 namespace amdgpu {
+#ifdef TENSORFLOW_USE_ROCM
 StatusOr<std::vector<uint8_t>> CompileToHsaco(
     llvm::Module* module, GpuVersion gpu_version,
     const HloModuleConfig& hlo_module_config,
     const std::string& rocdl_dir_path) {
-  static absl::once_flag backend_init_flag;
-  absl::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config);
+  static std::once_flag backend_init_flag;
+  std::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config);
 
   std::vector<uint8_t> hsaco;
   std::unique_ptr<llvm::TargetMachine> target_machine;
+
   std::string str;
   llvm::raw_string_ostream stream(str);
   stream << *module;
+
   // Delete the first two lines, since they usually vary even when the rest of
   // the code is the same (but verify that they are what we expect).
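+  // The remaining IR text is hashed below to form the on-disk HSACO cache
+  // file name.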
   if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") {
     auto pos = str.find('\n');
     if (pos != std::string::npos) str = str.substr(pos + 1);
   }
@@ -976,7 +1196,7 @@ StatusOr<std::vector<uint8_t>> CompileToHsaco(
     auto pos = str.find('\n');
     if (pos != std::string::npos) str = str.substr(pos + 1);
   }
-  str += hlo_module_config.compilation_cache_key();
+  // str += hlo_module_config.compilation_cache_key();
   {
     tsl::profiler::TraceMe activity(
        [&] { return absl::StrCat("Compiling IR", module->getName().str()); },
@@ -990,30 +1210,64 @@ StatusOr<std::vector<uint8_t>> CompileToHsaco(
           "Incompatible compute capability was specified.");
     }
 
+    llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
+    // Construct LLVM TargetMachine for AMDGPU.
+    std::unique_ptr<llvm::TargetMachine> target_machine =
+        AMDGPUGetTargetMachine(default_target_triple, gpu_version,
+                               hlo_module_config);
+#ifdef TENSORFLOW_HSACO_USE_ROCM_LLVM
+
     std::string gcn_arch_name = compute_capability->gcn_arch_name();
-    uint64_t hash;
-    if (HsacoCache::Find(str, hash, gcn_arch_name, hsaco)) {
+    std::string hsaco_filename =
+        absl::StrCat(std::hash<std::string>{}(str), gcn_arch_name, ".hsaco");
+    std::string hsaco_path =
+        tsl::io::JoinPath(hsaco_cache_dir_, hsaco_filename);
+
+    if (ReadHsaco(hsaco_path, hsaco).ok()) {
       VLOG(1) << "HSACO cache hit";
       return hsaco;
     }
     VLOG(1) << "HSACO cache miss";
-    bool dump_lls = false;
-    if (dump_lls) {
-      static int hsaco_count = 0;
-      std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll";
-      hsaco_count++;
-      std::ofstream ofs(name);
-      ofs << str;
-      ofs.close();
+
+    auto* env = tsl::Env::Default();
+    // Prepare filenames for all stages of compilation:
+    // IR, binary ISA, and HSACO.
+    std::string module_path;
+    if (!env->LocalTempFilename(&module_path)) {
+      return xla::InternalError(
+          "Could not get temporary filenames for modules.");
     }
+    std::string ir_path = absl::StrCat(module_path, ".ll");
 
-    llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz");
-    // Construct LLVM TargetMachine for AMDGPU.
-    std::unique_ptr<llvm::TargetMachine> target_machine =
-        AMDGPUGetTargetMachine(default_target_triple, gpu_version,
-                               hlo_module_config);
+    std::string linked_ir_path = absl::StrCat(module_path, "-linked.ll");
+
+    std::string optimized_ir_path = absl::StrCat(module_path, "-opt.ll");
+
+    std::string isabin_path = absl::StrCat(module_path, ".o");
+
+    // Link with ROCm-Device-Libs, and optimize the LLVM module.
+    TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
+        module, gpu_version, hlo_module_config, rocdl_dir_path,
+        AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(),
+        kAMDGPUInlineThreshold, ir_path, linked_ir_path, optimized_ir_path));
+    // Lower optimized LLVM module to HSA code object.
+    TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get(),
+                                                 optimized_ir_path, isabin_path,
+                                                 hsaco_path, gcn_arch_name));
+
+    std::async(
+        std::launch::async,
+        [](std::vector<std::string> files) {
+          for (auto& file : files) {
+            tsl::Env::Default()->DeleteFile(file);
+          }
+        },
+        std::vector<std::string>{ir_path, linked_ir_path, optimized_ir_path,
+                                 isabin_path});
+#else
     // Link with ROCm-Device-Libs, and optimize the LLVM module.
     TF_RETURN_IF_ERROR(LinkAndOptimizeModule(
        module, gpu_version, hlo_module_config, rocdl_dir_path,
        AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(),
        kAMDGPUInlineThreshold));
 
     // Lower optimized LLVM module to HSA code object.
     TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get()));
-    HsacoCache::Add(str, hash, gcn_arch_name, hsaco);
+#endif  // TENSORFLOW_HSACO_USE_ROCM_LLVM
   }
   return hsaco;
 }
-
+#endif  // TENSORFLOW_USE_ROCM
 }  // namespace amdgpu
 
 }  // namespace gpu

From 61f9c5bbc30530b02486a7b13a5fdedcaebf12f8 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:02:35 +0000
Subject: [PATCH 2/9] xla/gpu: Use full CSV string as cache key

---
 .../compiler/xla/service/gpu/autotuner_util.cc   | 15 +++++++++------
 .../xla/service/gpu/gemm_algorithm_picker.cc     |  4 ++--
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc
index fad209c33c8a04..c262f1c9c8bf29 100644
--- a/tensorflow/compiler/xla/service/gpu/autotuner_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/autotuner_util.cc
@@ -56,16 +56,19 @@ static auto& autotune_cache ABSL_GUARDED_BY(autotune_cache_mu) =
 
 namespace {
 
-void CSVLegend(std::ostream& os) {
+void CSVLegend(std::ostream& os, bool full_string = false) {
 
   os << kCsvComment << " m" << kCsvSep << "n" << kCsvSep << "k" << kCsvSep
      << "batch_count" << kCsvSep << "trans_a" << kCsvSep
-     << "trans_b" << kCsvSep
-     << "type_a" << kCsvSep << "type_b" << kCsvSep
+     << "trans_b" << kCsvSep << "type_a" << kCsvSep << "type_b" << kCsvSep
      << "type_c" << kCsvSep << "lda" << kCsvSep << "ldb" << kCsvSep
      << "ldc" << kCsvSep << "stride_a" << kCsvSep
-     << "stride_b" << kCsvSep << "stride_c" << kCsvSep
-     << "alg_index" << std::endl;
+     << "stride_b" << kCsvSep << "stride_c";
+  if (full_string) {
+    os << kCsvSep << "alpha_re" << kCsvSep << "alpha_im" << kCsvSep
+       << "beta" << kCsvSep << "epilogue";
+  }
+  os << kCsvSep << "alg_index" << std::endl;
 }
 
 }  // namespace
@@ -89,7 +92,7 @@ void CSVLegend(std::ostream& os) {
     if (!s_dump_fs->is_open()) {
       LOG(WARNING) << "Unable to open: " << dump_path << " for writing!";
     }
-    CSVLegend(*s_dump_fs);
+    CSVLegend(*s_dump_fs, true);
   }
   *s_dump_fs << key.Get() << kCsvSep << it->second << std::endl;
 }
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
index 0cd40a61e4f25f..f60d786ffedc36 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc
@@ -367,7 +367,7 @@ StatusOr<bool> RunOnInstruction(HloInstruction* gemm,
 
   GemmAutotuner autotuner(config);
   TF_ASSIGN_OR_RETURN(auto new_algorithm,
-      AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, false), config,
+      AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config,
        [&]() -> StatusOr<se::blas::AlgorithmType> {
          TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm, gemm_config));
          return algo.has_gemm() ?
                algo.gemm().algorithm() : se::blas::kDefaultAlgorithm;
@@ -410,7 +410,7 @@ StatusOr<se::blas::AlgorithmType> GemmAlgorithmPicker::RunStandalone(
 
   GemmAutotuner autotuner(config_);
   GemmConfig gemm_config{cfg};
-  return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, false), config_,
+  return AutotunerUtil::Autotune(se::gpu::ToCSVString(gemm_config, true), config_,
      [&]() -> StatusOr<se::blas::AlgorithmType> {
        TF_ASSIGN_OR_RETURN(auto algo, autotuner(gemm_config,
            std::move(input_shapes), output_shape, debug_options));

From 2875d46618d5a2afe76b9c7b185b75b8ad028f5a Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:04:25 +0000
Subject: [PATCH 3/9] xla/gpu: Use default algorithm as fallback when the
 cached algorithm can't be found

---
 .../compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc
index 984244c0892155..b7e8f3176dc404 100644
--- a/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpublas_lt_matmul_thunk.cc
@@ -173,7 +173,9 @@ auto CublasLtMatmulThunk::GetCachedMatmulPlan(
        return std::move(plan);
      }
    }
-    return InternalError("Wrong algorithm ID: %d", algorithm_id);
+    TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[0]));
+    LOG(WARNING) << "Wrong algorithm ID: " << algorithm_id << ", using the default instead.";
+    return std::move(plan);
   };
   return cache.GetOrCreate(canonical_hlo_, create);
 }

From 76f797616fc32c20264827bc13b2fff67c03af95 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:05:59 +0000
Subject: [PATCH 4/9] Use full CSV string

---
 tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
index 195e1161a3aa3a..69f5b8f401f159 100644
--- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
+++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
@@ -264,8 +264,7 @@ std::string ToCSVString(const GemmConfig& cfg, bool full_string) {
 
   if (full_string) {
     // NOTE: epilogue is required for MatmulPlan caching !
-    oss //<< kCsvSep << cfg.alpha << kCsvSep << cfg.beta
-        << kCsvSep << (int64_t)cfg.epilogue;
+    oss << kCsvSep << cfg.alpha.real() << kCsvSep << cfg.alpha.imag() << kCsvSep << cfg.beta << kCsvSep << (int64_t)cfg.epilogue;
   }
 
   return oss.str();

From 37b9a1af5fe7f6a4710612664b23a74cf5d165f6 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:12:53 +0000
Subject: [PATCH 5/9] Update the data layout to avoid warnings

---
 tensorflow/compiler/xla/service/gpu/target_constants.h | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/target_constants.h b/tensorflow/compiler/xla/service/gpu/target_constants.h
index cec1182f0a5fb5..bf62b07f0ef775 100644
--- a/tensorflow/compiler/xla/service/gpu/target_constants.h
+++ b/tensorflow/compiler/xla/service/gpu/target_constants.h
@@ -46,9 +46,7 @@ inline const char* TargetTriple() {
 
 // The data layout of the emitted module.
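+// The replacement layout below matches the data layout used by the ROCm LLVM
+// AMDGPU backend; keeping the two in sync avoids the data-layout mismatch
+// warnings this commit is meant to silence.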
 inline const char* DataLayout() {
   static constexpr char kDataLayout[] =
-      "e-p:64:64-p1:64:64-p2:64:64-p3:32:32-p4:32:32-p5:32:32"
-      "-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128"
-      "-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-A5";
+      "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9";
   return kDataLayout;
 }

From 465dd6210168fe7956a70ca3274a583658f8131f Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:20:25 +0000
Subject: [PATCH 6/9] add cs8 dockerfile

---
 tensorflow/tools/ci_build/Dockerfile.cs8.rocm | 158 ++++++++++++++++++
 1 file changed, 158 insertions(+)
 create mode 100644 tensorflow/tools/ci_build/Dockerfile.cs8.rocm

diff --git a/tensorflow/tools/ci_build/Dockerfile.cs8.rocm b/tensorflow/tools/ci_build/Dockerfile.cs8.rocm
new file mode 100644
index 00000000000000..0372a3772537fc
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.cs8.rocm
@@ -0,0 +1,158 @@
+# This Dockerfile provides a starting point for a ROCm installation of
+# MIOpen and tensorflow.
+FROM almalinux:8
+MAINTAINER Jeff Poznanovic
+
+ARG RPM_ROCM_REPO=https://repo.radeon.com/rocm/rhel8/.yum_6.3.0.1/main
+ARG ROCM_PATH=/opt/rocm-6.3.0.1
+
+ENV ROCM_PATH=$ROCM_PATH
+ENV DEBIAN_FRONTEND noninteractive
+ENV TF_NEED_ROCM 1
+ENV GCC_HOST_COMPILER_PATH=/opt/rocm/bin/amdclang
+ENV HOME /root/
+RUN dnf update -y && dnf install -y epel-release && dnf install -y elrepo-release && dnf config-manager --set-enabled powertools
+# Setup the build_system repo
+RUN echo -e "[build_system]\nname=ROCm\nbaseurl=https://repo.almalinux.org/build_system/8/x86_64/\nenabled=1\ngpgcheck=0" >/etc/yum.repos.d/build_system.repo
+RUN dnf group install -y "Development Tools"
+
+RUN /bin/bash -c 'echo -e "[ROCm]\nname=ROCm\nbaseurl=$RPM_ROCM_REPO\nenabled=1\ngpgcheck=0" >>/etc/yum.repos.d/rocm.repo'
+RUN /bin/bash -c 'echo -e "[amdgpu]\nname=amdgpu\nbaseurl=https://repo.radeon.com/amdgpu/.6.3.0.1/rhel/8.8/main/x86_64/\nenabled=1\ngpgcheck=0" >> /etc/yum.repos.d/amdgpu.repo'
+
+RUN dnf clean all
+RUN dnf update -y
+
+# Install misc pkgs
+RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
+    epel-release \
+    openssl-devel \
+    libffi-devel \
+    hdf5-devel \
+    wget \
+    make \
+    patch \
+    zlib-devel \
+    bzip2 \
+    bzip2-devel \
+    readline \
+    readline-devel \
+    sqlite \
+    sqlite-devel \
+    openssl-devel \
+    tk-devel \
+    xz-devel
+
+RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
+    bc \
+    bridge-utils \
+    cmake \
+    cmake3 \
+    devscripts \
+    dkms \
+    doxygen \
+    dpkg \
+    dpkg-dev \
+    dpkg-perl \
+    elfutils-libelf-devel \
+    expect \
+    file \
+    gettext \
+    gcc-c++ \
+    git \
+    libgcc \
+    ncurses \
+    ncurses-base \
+    ncurses-libs \
+    numactl-devel \
+    numactl-libs \
+    libssh \
+    libunwind-devel \
+    libunwind \
+    llvm \
+    llvm-libs \
+    make \
+    openssl \
+    openssl-libs \
+    openssh \
+    openssh-clients \
+    pciutils \
+    pciutils-devel \
+    pciutils-libs \
+    java-11-openjdk-devel \
+    patchelf \
+    pkgconfig \
+    npth \
+    qemu-kvm \
+    re2c \
+    rpm \
+    rpm-build \
+    subversion \
+    sudo \
+    wget \
+    kernel-devel-uname-r
+
+RUN dnf --enablerepo=extras,build_system install -y \
+    libdrm-amdgpu \
+    rocm-dev \
+    rocm-ml-sdk \
+    miopen-hip \
+    miopen-hip-devel \
+    rocblas \
+    rocblas-devel \
+    rocsolver-devel \
+    rocrand-devel \
+    rocfft-devel \
+    hipfft-devel \
+    hipblas-devel \
+    rocprim-devel \
+    hipcub-devel \
+    rccl-devel \
+    hipsparse-devel \
+    hipsolver-devel \
+    hipblas-common-devel \
+    rocm-llvm-devel \
+    boost-devel
+
+RUN dnf --enablerepo=extras,epel,elrepo,powertools,build_system install -y \
+    python3.11 \
+    python3.11-devel \
+    python3.11-pip \
+    python3.11-wheel
+
+RUN ln -sf /usr/bin/python3.11 /usr/bin/python3
+RUN ln -sf /usr/bin/python3 /usr/bin/python
+RUN ln -sf /usr/bin/python3.11 /etc/alternatives/python3
+
+RUN python3 -m ensurepip
+RUN pip install joblib numpy==1.24.0 requests packaging
+
+ENV OPENCL_ROOT=$ROCM_PATH/opencl
+ENV PATH="$ROCM_PATH/bin:${PATH}"
+ENV PATH="$OPENCL_ROOT/bin:${PATH}"
+
+# Workaround: explicitly add a symbolic link to /opt/rocm
+RUN touch ${ROCM_PATH}/.info/version
+RUN bash -c 'ln -s ${ROCM_PATH} /opt/rocm'
+
+# Add target file to help determine which device(s) to build for
+RUN bash -c 'echo -e "gfx942\ngfx90a\n" >> ${ROCM_PATH}/bin/target.lst'
+
+# Setup environment variables, and add those environment variables at the end of ~/.bashrc
+ARG PATH=$HCC_HOME/bin:$HIP_PATH/bin:$PATH
+
+COPY install/*.sh /install/
+
+SHELL ["/bin/bash", "-c"]
+RUN /install/install_bazel.sh
+RUN /install/install_golang.sh
+
+# Configure the build for our ROCm configuration.
+ENV TF_NEED_ROCM 1
+
+# This is a temporary workaround to fix Out-Of-Memory errors we are running into with XLA perf tests
+# By default, HIP runtime "hides" 256MB from the TF Runtime, but with recent changes (update to ROCm2.3, dynamic loading of roc* libs, et al)
+# it seems that we need to up the threshold slightly to 320MB
+ENV HIP_HIDDEN_FREE_MEM=320
+
+# We'll be using a custom CK build in this branch
+# RUN bash -c 'mv ${ROCM_PATH}/include/ck ${ROCM_PATH}/include/ck-back'
\ No newline at end of file

From 88529bf90d43401f5c87cfc573a32fef8396d545 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:25:35 +0000
Subject: [PATCH 7/9] Enable clang as a compiler option

---
 third_party/gpus/crosstool/BUILD.rocm.tpl      |  4 ++--
 .../bin/crosstool_wrapper_driver_rocm.tpl      | 20 ++++++++++++++++---
 .../hipcc_cc_toolchain_config.bzl.tpl          |  2 +-
 third_party/gpus/rocm_configure.bzl            |  4 +++-
 4 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/gpus/crosstool/BUILD.rocm.tpl
index a742cfcd208ec1..264b7b52b67d54 100644
--- a/third_party/gpus/crosstool/BUILD.rocm.tpl
+++ b/third_party/gpus/crosstool/BUILD.rocm.tpl
@@ -87,14 +87,14 @@ cc_toolchain_config(
        "-fuse-ld=gold",
        "-Wl,-no-as-needed",
        "-Wl,-z,relro,-z,now",
-        "-pass-exit-codes",
+        # "-pass-exit-codes",
        "-lstdc++",
        "-lm",
     ],
     link_libs = [],
     opt_link_flags = [],
     unfiltered_compile_flags = [
-        "-fno-canonical-system-headers",
+        # "-fno-canonical-system-headers",
        "-Wno-builtin-macro-redefined",
        "-D__DATE__=\"redacted\"",
        "-D__TIMESTAMP__=\"redacted\"",
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
index 794a80cbf06fdb..c3218cec62c9e4 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
@@ -75,7 +75,9 @@ def GetHostCompilerOptions(argv):
   parser.add_argument('-iquote', nargs='*', action='append')
   parser.add_argument('--sysroot', nargs=1)
   parser.add_argument('-g', nargs='*', action='append')
-  parser.add_argument('-fno-canonical-system-headers', action='store_true')
+  parser.add_argument('-no-canonical-prefixes', action='store_true')
+  parser.add_argument('-Wno-unused-variable',
+                      action='store_true')
+  parser.add_argument('-Wno-unused-but-set-variable', action='store_true')
 
   args, _ = parser.parse_known_args(argv)
@@ -87,10 +89,16 @@ def GetHostCompilerOptions(argv):
     opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
   if args.g:
     opts += ' -g' + ' -g'.join(sum(args.g, []))
-  if args.fno_canonical_system_headers:
+  if args.no_canonical_prefixes:
     opts += ' -no-canonical-prefixes'
   if args.sysroot:
     opts += ' --sysroot ' + args.sysroot[0]
+  if args.Wno_unused_variable:
+    opts += ' -Wno-unused-variable'
+
+  if args.Wno_unused_but_set_variable:
+    opts += ' -Wno-unused-but-set-variable'
+
   return opts
 
@@ -282,7 +290,13 @@ def main():
                           if not flag.startswith(('--rocm_log'))]
 
   # XXX: SE codes need to be built with gcc, but need this macro defined
-  cpu_compiler_flags.append("-D__HIP_PLATFORM_HCC__")
+  cpu_compiler_flags.append("-D__HIP_PLATFORM_AMD__")
+  cpu_compiler_flags.append('-L' + HIP_RUNTIME_PATH)
+  cpu_compiler_flags.append('-Wl,-rpath=' + HIP_RUNTIME_PATH)
+  cpu_compiler_flags.append('-l' + HIP_RUNTIME_LIBRARY)
+  cpu_compiler_flags.append("-lrt")
+  cpu_compiler_flags.append("-Wno-unused-command-line-argument")
+  cpu_compiler_flags.append("-Wno-gnu-offsetof-extensions")
   if VERBOSE: print(' '.join([CPU_COMPILER] + cpu_compiler_flags))
   return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
diff --git a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
index e0541defa34687..17741367aff556 100644
--- a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
+++ b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl
@@ -1046,7 +1046,7 @@ def _impl(ctx):
                 flag_group(
                     flags = [
                         "-no-canonical-prefixes",
-                        "-fno-canonical-system-headers",
+                        # "-fno-canonical-system-headers",
                     ]
                 ),
             ],
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
index a568b1c5f517e8..48e0ea06a6fb9e 100644
--- a/third_party/gpus/rocm_configure.bzl
+++ b/third_party/gpus/rocm_configure.bzl
@@ -715,12 +715,14 @@ def _create_local_rocm_repository(repository_ctx):
     # .d file - given that includes that are prefixed with "../" multiple
     # time quickly grow longer than the root of the tree, this can lead to
     # bazel's header check failing.
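+    # clang does not support GCC's -fno-canonical-system-headers, so no extra
+    # flag is injected here when building with the ROCm clang toolchain.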
- rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "" rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", "-D__HIP_PLATFORM_AMD__", "-DEIGEN_USE_HIP", + "-Wno-unused-but-set-variable", + "-Wno-c++11-narrowing", ]) rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" From 10ed244bb204421d340b1e2fcefb2ff228fdaee1 Mon Sep 17 00:00:00 2001 From: fsx950223 Date: Thu, 21 Nov 2024 03:30:14 +0000 Subject: [PATCH 8/9] refactor build_rocm_python3 --- build_rocm_python3 | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/build_rocm_python3 b/build_rocm_python3 index 27eedd2b916313..e5bacfc074cd41 100755 --- a/build_rocm_python3 +++ b/build_rocm_python3 @@ -47,15 +47,15 @@ if [ -f /usertools/rocm.bazelrc ]; then if [[ -n $nightly ]]; then # Remove any previous builds and build nightly rm -f $TF_PKG_LOC/tf_nightly_rocm*.whl - python3 tensorflow/tools/ci_build/update_version.py --nightly --rocm_version && - bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures && + #python3 tensorflow/tools/ci_build/update_version.py --nightly --rocm_version && + bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures && ./bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --nightly_flag && pip3 install --upgrade $TF_PKG_LOC/tf_nightly_rocm*.whl else # Remove any previous builds and build release rm -f $TF_PKG_LOC/tensorflow*.whl - python3 tensorflow/tools/ci_build/update_version.py --rocm_version && - bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures && + #python3 tensorflow/tools/ci_build/update_version.py --rocm_version && + bazel --bazelrc=/usertools/rocm.bazelrc build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=rocm --action_env=TF_PYTHON_VERSION=$PYTHON_VERSION tensorflow/tools/pip_package:build_pip_package --verbose_failures && ./bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --project_name tensorflow_rocm && pip3 install --upgrade $TF_PKG_LOC/tensorflow*.whl fi @@ -66,13 +66,13 @@ else if [[ -n $nightly ]]; then # Remove any previous builds and build nightly rm -f $TF_PKG_LOC/tf_nightly_rocm*.whl - bazel build $RESOURCE_OPTION --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures && + bazel build $RESOURCE_OPTION --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-DTENSORFLOW_HSACO_USE_ROCM_LLVM" --config=v1 --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package --verbose_failures && bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm --nightly_flag && pip3 install --upgrade $TF_PKG_LOC/tf_nightly_rocm*.whl else # Remove any previous builds and build release rm -f $TF_PKG_LOC/tensorflow*.whl - bazel build $RESOURCE_OPTION --config=opt --config=rocm //tensorflow/tools/pip_package:build_pip_package 
--verbose_failures &&
    bazel-bin/tensorflow/tools/pip_package/build_pip_package $TF_PKG_LOC --rocm &&
    pip3 install --upgrade $TF_PKG_LOC/tensorflow*.whl
  fi

From 180c0cf6ec09e8c09985f875b7078eeaeabe13a1 Mon Sep 17 00:00:00 2001
From: fsx950223
Date: Thu, 21 Nov 2024 03:34:13 +0000
Subject: [PATCH 9/9] add hlo benchmark

---
 tensorflow/tools/hlo_benchmark/README.md    | 13 ++++
 .../tools/hlo_benchmark/hlo_estimate.py     | 75 +++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 tensorflow/tools/hlo_benchmark/README.md
 create mode 100644 tensorflow/tools/hlo_benchmark/hlo_estimate.py

diff --git a/tensorflow/tools/hlo_benchmark/README.md b/tensorflow/tools/hlo_benchmark/README.md
new file mode 100644
index 00000000000000..ec3fef94c49f47
--- /dev/null
+++ b/tensorflow/tools/hlo_benchmark/README.md
@@ -0,0 +1,13 @@
+# A script for computing HLO module FLOPS
+## Build
+```
+bazel build --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --config=v1 --config=opt --config=rocm tensorflow/compiler/xla/tools:run_hlo_module tensorflow/compiler/xla/tools:compute_cost --verbose_failures
+```
+
+## Usage
+```
+python tensorflow/tools/hlo_benchmark/hlo_estimate.py --hlo=tensorflow/compiler/xla/tests/*.gfx942_gpu_after_optimizations.txt --output=result.txt
+```
+
+## Example Output
+slow_xla_sample_v2/module_5629.cluster_4221__XlaCompiledKernel_true__XlaHasReferenceVars_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.785.gfx942_gpu_after_optimizations.txt 4332.0634140346565 GFLOPS/s 0.0008137 s
\ No newline at end of file
diff --git a/tensorflow/tools/hlo_benchmark/hlo_estimate.py b/tensorflow/tools/hlo_benchmark/hlo_estimate.py
new file mode 100644
index 00000000000000..3f6a0163f76a23
--- /dev/null
+++ b/tensorflow/tools/hlo_benchmark/hlo_estimate.py
@@ -0,0 +1,75 @@
+import subprocess
+import glob
+import re
+import argparse
+
+parser = argparse.ArgumentParser(
+    description="""Estimate the FLOPS of HLO modules""")
+
+parser.add_argument(
+    "--hlo",
+    type=str,
+    help="Glob path to HLO modules")
+
+parser.add_argument(
+    "--output",
+    type=str,
+    help="Output file path")
+
+parser.add_argument(
+    "--warmup", type=int, default=10,
+    help="Warmup iterations")
+
+parser.add_argument(
+    "--iters", type=int, default=10,
+    help="Max tuning iterations")
+
+args = parser.parse_args()
+
+# Paths to the input and output files
+# PATH = "/home/sixifang/tensorflow-upstream/bubble_test_xla_dump/*.gfx90a_gpu_after_optimizations.txt"
+# OUTPUT_FILE = "result.txt"
+PATH = args.hlo
+OUTPUT_FILE = args.output
+
+HLO_BENCH_RE = r"execution time for runner ROCM: (?P