diff --git a/clients/common/client_utility.cpp b/clients/common/client_utility.cpp index a077dd20e..01342a3c1 100644 --- a/clients/common/client_utility.cpp +++ b/clients/common/client_utility.cpp @@ -382,10 +382,30 @@ rocblas_local_handle::rocblas_local_handle() } rocblas_local_handle::rocblas_local_handle(const Arguments& arg) - : rocblas_local_handle() { + if(arg.use_hipblaslt >= 0) + { + auto hipblaslt_env = getenv("ROCBLAS_USE_HIPBLASLT"); + if(hipblaslt_env) + m_hipblaslt_saved_status = std::string(hipblaslt_env); + m_hipblaslt_env_set = true; + setenv("ROCBLAS_USE_HIPBLASLT", std::to_string(arg.use_hipblaslt).c_str(), true); + } + + auto status = rocblas_create_handle(&m_handle); + if(status != rocblas_status_success) + throw std::runtime_error(rocblas_status_to_string(status)); + +#ifdef GOOGLE_TEST + if(t_set_stream_callback) + { + (*t_set_stream_callback)(m_handle); + t_set_stream_callback.reset(); + } +#endif + // Set the atomics mode - auto status = rocblas_set_atomics_mode(m_handle, arg.atomics_mode); + status = rocblas_set_atomics_mode(m_handle, arg.atomics_mode); // The check_numerics mode conditional defeat with "rocblas_check_numerics_mode_no_check" // Defeat check numerics when initializing any data with NaN with due to alpha or beta having NaN flags, @@ -438,6 +458,11 @@ rocblas_local_handle::~rocblas_local_handle() if(m_memory) (hipFree)(m_memory); + if(m_hipblaslt_env_set) + { + setenv("ROCBLAS_USE_HIPBLASLT", m_hipblaslt_saved_status.c_str(), true); + } + rocblas_destroy_handle(m_handle); } diff --git a/clients/gtest/gemm_gtest.yaml b/clients/gtest/gemm_gtest.yaml index a3a7173c8..4a0b1843b 100644 --- a/clients/gtest/gemm_gtest.yaml +++ b/clients/gtest/gemm_gtest.yaml @@ -3988,6 +3988,16 @@ Tests: use_hipblaslt: 1 os_flags: LINUX +- name: non_hipblaslt_f16 + category: pre_checkin + function: gemm_ex + precision: *hpa_half_precision + transA_transB: *transA_transB_range + alpha_beta: *alpha_beta_range + matrix_size: *medium_matrix_size_range + use_hipblaslt: 0 + os_flags: LINUX + # Commented out as the category known_bug has currently no good purpose # This is category known_bug until the sizes are supported by Tensile, diff --git a/clients/include/client_utility.hpp b/clients/include/client_utility.hpp index 782fe1064..b4ca613e6 100644 --- a/clients/include/client_utility.hpp +++ b/clients/include/client_utility.hpp @@ -116,6 +116,8 @@ class rocblas_local_handle void* m_memory{nullptr}; hipStream_t m_graph_stream{nullptr}; hipStream_t m_old_stream{nullptr}; + std::string m_hipblaslt_saved_status = ""; + bool m_hipblaslt_env_set{false}; void rocblas_stream_begin_capture(); void rocblas_stream_end_capture(); diff --git a/clients/include/type_dispatch.hpp b/clients/include/type_dispatch.hpp index 43e567084..45def8c4d 100644 --- a/clients/include/type_dispatch.hpp +++ b/clients/include/type_dispatch.hpp @@ -1,5 +1,5 @@ /* ************************************************************************ - * Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2018-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -440,20 +440,6 @@ auto rocblas_gemv_batched_and_strided_batched_dispatch(const Arguments& arg) template