Skip to content

Commit

Permalink
Reset hipblaslt env var on handle destruction (#2844)
Browse files Browse the repository at this point in the history
  • Loading branch information
daineAMD authored Jan 29, 2025
1 parent 2fbbb10 commit 8f2e0c5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 17 deletions.
29 changes: 27 additions & 2 deletions clients/common/client_utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,30 @@ rocblas_local_handle::rocblas_local_handle()
}

rocblas_local_handle::rocblas_local_handle(const Arguments& arg)
: rocblas_local_handle()
{
if(arg.use_hipblaslt >= 0)
{
auto hipblaslt_env = getenv("ROCBLAS_USE_HIPBLASLT");
if(hipblaslt_env)
m_hipblaslt_saved_status = std::string(hipblaslt_env);
m_hipblaslt_env_set = true;
setenv("ROCBLAS_USE_HIPBLASLT", std::to_string(arg.use_hipblaslt).c_str(), true);
}

auto status = rocblas_create_handle(&m_handle);
if(status != rocblas_status_success)
throw std::runtime_error(rocblas_status_to_string(status));

#ifdef GOOGLE_TEST
if(t_set_stream_callback)
{
(*t_set_stream_callback)(m_handle);
t_set_stream_callback.reset();
}
#endif

// Set the atomics mode
auto status = rocblas_set_atomics_mode(m_handle, arg.atomics_mode);
status = rocblas_set_atomics_mode(m_handle, arg.atomics_mode);

// The check_numerics mode conditional defeat with "rocblas_check_numerics_mode_no_check"
// Defeat check numerics when initializing any data with NaN with due to alpha or beta having NaN flags,
Expand Down Expand Up @@ -438,6 +458,11 @@ rocblas_local_handle::~rocblas_local_handle()
if(m_memory)
(hipFree)(m_memory);

if(m_hipblaslt_env_set)
{
setenv("ROCBLAS_USE_HIPBLASLT", m_hipblaslt_saved_status.c_str(), true);
}

rocblas_destroy_handle(m_handle);
}

Expand Down
10 changes: 10 additions & 0 deletions clients/gtest/gemm_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3988,6 +3988,16 @@ Tests:
use_hipblaslt: 1
os_flags: LINUX

- name: non_hipblaslt_f16
category: pre_checkin
function: gemm_ex
precision: *hpa_half_precision
transA_transB: *transA_transB_range
alpha_beta: *alpha_beta_range
matrix_size: *medium_matrix_size_range
use_hipblaslt: 0
os_flags: LINUX


# Commented out as the category known_bug has currently no good purpose
# This is category known_bug until the sizes are supported by Tensile,
Expand Down
2 changes: 2 additions & 0 deletions clients/include/client_utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class rocblas_local_handle
void* m_memory{nullptr};
hipStream_t m_graph_stream{nullptr};
hipStream_t m_old_stream{nullptr};
std::string m_hipblaslt_saved_status = "";
bool m_hipblaslt_env_set{false};

void rocblas_stream_begin_capture();
void rocblas_stream_end_capture();
Expand Down
16 changes: 1 addition & 15 deletions clients/include/type_dispatch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* ************************************************************************
* Copyright (C) 2018-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2018-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -440,20 +440,6 @@ auto rocblas_gemv_batched_and_strided_batched_dispatch(const Arguments& arg)
template <template <typename...> class TEST>
auto rocblas_gemm_dispatch(const Arguments& arg)
{
int setenv_status;
if(arg.use_hipblaslt != -1)
{
setenv_status
= setenv("ROCBLAS_USE_HIPBLASLT", std::to_string(arg.use_hipblaslt).c_str(), true);
}
else
{
setenv_status = unsetenv("ROCBLAS_USE_HIPBLASLT");
}
#ifdef GOOGLE_TEST
EXPECT_EQ(setenv_status, 0);
#endif

const auto Ti = arg.a_type, To = arg.c_type, Tc = arg.compute_type;
const auto Tc_new = arg.composite_compute_type;

Expand Down

0 comments on commit 8f2e0c5

Please sign in to comment.