diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp index 2316870f96..a413ad1eb0 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp @@ -52,7 +52,7 @@ void bounds_check_indices_cpu( Tensor& indices, Tensor& offsets, int64_t bounds_check_mode_, - Tensor& warning, + [[maybe_unused]] Tensor& warning, const std::optional& weights, const std::optional& B_offsets, const int64_t max_B, @@ -71,15 +71,11 @@ void bounds_check_indices_cpu( vbe ? B_offsets.value().data_ptr() : nullptr; auto bounds_check_mode = static_cast(bounds_check_mode_); - if (bounds_check_mode == BoundsCheckMode::WARNING) { - warning.zero_(); - } const int32_t T = rows_per_table.size(0); const int32_t total_B = offsets.size(0) - 1; const int32_t B = total_B / T; const auto rows_per_table_acc = rows_per_table.accessor(); - auto warning_acc = warning.data_ptr(); AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices_cpu", [&] { [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "bounds_check_indices_cpu"; @@ -107,16 +103,14 @@ void bounds_check_indices_cpu( TORCH_CHECK(num_indices == offsets_acc[total_B]); } else if (bounds_check_mode == BoundsCheckMode::WARNING) { if (num_indices != offsets_acc[total_B]) { - if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) { - LOG(ERROR) - << "EmbeddingBoundsCheck (VBE " << (vbe ? "true" : "false") - << "): The last element in offsets is incorrect for " - << "total batch size " << (vbe ? "total_B" : "B") << ": " - << (vbe ? total_B : B) << ", total table num T: " << T - << ", last element in offsets: " << offsets_acc[total_B] - << ", indices size: " << num_indices - << ". Setting the last element in offsets to be indices size."; - } + LOG(ERROR) + << "EmbeddingBoundsCheck (VBE " << (vbe ? "true" : "false") + << "): The last element in offsets is incorrect for " + << "total batch size " << (vbe ? "total_B" : "B") << ": " + << (vbe ? total_B : B) << ", total table num T: " << T + << ", last element in offsets: " << offsets_acc[total_B] + << ", indices size: " << num_indices + << ". Setting the last element in offsets to be indices size."; offsets_acc[total_B] = num_indices; } } else if (bounds_check_mode == BoundsCheckMode::IGNORE) { @@ -124,6 +118,7 @@ void bounds_check_indices_cpu( offsets_acc[total_B] = num_indices; } } + bool warning_logged = false; for (const auto t : c10::irange(T)) { auto num_rows = rows_per_table_acc[t]; auto B_begin = vbe ? B_offsets_ptr[t] : t * B; @@ -138,7 +133,7 @@ void bounds_check_indices_cpu( } else if (bounds_check_mode == BoundsCheckMode::WARNING) { if (indices_start < 0 || indices_start > indices_end || indices_end > num_indices) { - if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) { + if (!warning_logged) { LOG(ERROR) << "EmbeddingBoundsCheck (VBE " << (vbe ? "true" : "false") << "): (at least one) Out of bounds access for batch: " << b @@ -146,6 +141,7 @@ void bounds_check_indices_cpu( << ", indices_end: " << indices_end << ", num_indices: " << num_indices << ". Setting indices_start and indices_end within the range"; + warning_logged = true; } adjust_offset_cpu( indices_start, @@ -175,13 +171,14 @@ void bounds_check_indices_cpu( TORCH_CHECK(idx < num_rows); } else if (bounds_check_mode == BoundsCheckMode::WARNING) { if (idx < 0 || idx >= num_rows) { - if (__sync_fetch_and_add(&warning_acc[0], 1) == 0) { + if (!warning_logged) { LOG(ERROR) << "EmbeddingBoundsCheck (VBE " << (vbe ? "true" : "false") << "): (at least one) Out of bounds access for batch: " << b << ", table: " << t << ", bag element: " << l << ", idx: " << idx << ", num_rows: " << num_rows << ". Setting idx to zero."; + warning_logged = true; } indices_acc[indices_start + l] = 0; }