Skip to content

Commit

Permalink
Fix issues with buffer length when getting brand name
Browse files Browse the repository at this point in the history
* Specifically, address case when brand name is longer than buffer
provided

* Also, slightly modify prototype to match similar, existing APIs.

* Address some cpplint issues.

Change-Id: Iaf77304e23085123e88f301e4b33bc4e6be2a225
  • Loading branch information
Chris Freehill authored and Chris Freehill committed Aug 26, 2019
1 parent 7f2d970 commit 01e0800
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/rocm_smi/rocm_smi.h
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ rsmi_status_t rsmi_dev_name_get(uint32_t dv_ind, char *name, size_t len);
* @retval ::RSMI_STATUS_SUCCESS is returned upon successful call.
*
*/
rsmi_status_t rsmi_dev_brand_get(uint32_t dv_ind, char *brand, size_t len);
rsmi_status_t rsmi_dev_brand_get(uint32_t dv_ind, char *brand, uint32_t len);

/**
* @brief Get the name string for a give vendor ID
Expand Down
17 changes: 11 additions & 6 deletions src/rocm_smi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1423,14 +1423,13 @@ rsmi_dev_name_get(uint32_t dv_ind, char *name, size_t len) {
}

rsmi_status_t
rsmi_dev_brand_get(uint32_t dv_ind, char *brand, size_t len) {
rsmi_dev_brand_get(uint32_t dv_ind, char *brand, uint32_t len) {
GET_DEV_FROM_INDX
// Return 'invalid args' if arguments are invalid
if (brand == nullptr || len == 0){
if (brand == nullptr || len == 0) {
return RSMI_STATUS_INVALID_ARGS;
}
std::map<std::string, std::string> brand_names =
{
std::map<std::string, std::string> brand_names = {
{"D05121", "mi25"},
{"D05131", "mi25"},
{"D05133", "mi25"},
Expand All @@ -1447,11 +1446,17 @@ rsmi_dev_brand_get(uint32_t dv_ind, char *brand, size_t len) {
return errno_to_rsmi_status(ret);
}
if (vbios_value.length() == 16) {
sku_value = vbios_value.substr(4,6);
sku_value = vbios_value.substr(4, 6);
// Find the brand name using sku_value
it = brand_names.find(sku_value);
if (it != brand_names.end()) {
strcpy(brand, it->second.c_str());
uint32_t ln = it->second.copy(brand, len);
brand[std::min(len - 1, ln)] = '\0';

if (len < (it->second.size() + 1)) {
return RSMI_STATUS_INSUFFICIENT_SIZE;
}

return RSMI_STATUS_SUCCESS;
}
}
Expand Down

0 comments on commit 01e0800

Please sign in to comment.