Skip to content

Commit

Permalink
Add safe-guard for masked version
Browse files Browse the repository at this point in the history
  • Loading branch information
byjtew committed Jul 3, 2023
1 parent d80ed92 commit e616087
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions include/graphblas/reference/blas3.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1327,24 +1327,26 @@ namespace grb {
for( auto i = start_row; i < end_row; ++i ) {
auto mask_k = mask_raw.col_start[ i ];
for( auto k = A_crs_raw.col_start[ i ]; k < A_crs_raw.col_start[ i + 1 ]; ++k ) {
auto k_col = A_crs_raw.row_index[ k ];
const auto j = A_crs_raw.row_index[ k ];

// Increment the mask pointer until we find the right column, or a lower column (since the storage withing a row is sorted in a descending order)
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > k_col ) {
while( mask_k < mask_raw.col_start[ i + 1 ] && mask_raw.row_index[ mask_k ] > j ) {
_DEBUG_THREADESAFE_PRINT( "NEquals masked coordinate: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
mask_k++;
}

if( mask_raw.row_index[ mask_k ] < k_col || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
mask_k++;
if( mask_k >= mask_raw.col_start[ i + 1 ] ) {
_DEBUG_THREADESAFE_PRINT( "No value left for this column\n" );
break;
}
if( mask_raw.row_index[ mask_k ] < j || not MaskHasValue< MaskType >( mask_raw, mask_k ).value ) {
_DEBUG_THREADESAFE_PRINT( "Skip masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
continue;
}

_DEBUG_THREADESAFE_PRINT( "Found masked value at: ( " + std::to_string( i ) + ";" + std::to_string( mask_raw.row_index[ mask_k ] ) + " )\n" );
// Get A value
const auto a_val_before = A_crs_raw.values[ k ];
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( k_col ) + " ) = " + std::to_string( a_val_before ) + "\n" );
_DEBUG_THREADESAFE_PRINT( "A( " + std::to_string( i ) + ";" + std::to_string( j ) + " ) = " + std::to_string( a_val_before ) + "\n" );
// Compute the fold for this coordinate
local_rc = local_rc ? local_rc : grb::apply< descr >( A_crs_raw.values[ k ], a_val_before, x, op );
local_rc = local_rc ? local_rc : grb::apply< descr >( A_ccs_raw.values[ k ], a_val_before, x, op );
Expand Down

0 comments on commit e616087

Please sign in to comment.