Skip to content

Commit

Permalink
Refactor SM90 radix_sort tuning (#3125)
Browse files: browse the repository at this point in the history
  • Loading branch information
bernhardmgruber authored Dec 11, 2024
1 parent 346a618 commit 29ba731
Showing 1 changed file with 15 additions and 49 deletions.
64 changes: 15 additions & 49 deletions cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh
Original file line number | Diff line number | Diff line change
@@ -761,20 +761,19 @@ struct policy_hub
/// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy800>
{
enum
{
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
ONESWEEP = true,
ONESWEEP_RADIX_BITS = 8,
OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0,
FLOAT_KEYS = std::is_same<KeyT, float>::value ? 1 : 0,
};
static constexpr bool ONESWEEP = true;
static constexpr int ONESWEEP_RADIX_BITS = 8;

using HistogramPolicy = AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>;
using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>;

private:
static constexpr int PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5;
static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
static constexpr int SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
static constexpr int OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0;
static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float>::value ? 1 : 0;

using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
384,
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
@@ -796,11 +795,11 @@ struct policy_hub
RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS>;

using OnesweepLargeKeyPolicy = //
::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;
using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;

using OnesweepSmallKeyPolicySizes =
sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;

using OnesweepSmallKeyPolicySizes = //
detail::radix::sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;
using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
OnesweepSmallKeyPolicySizes::threads,
OnesweepSmallKeyPolicySizes::items,
@@ -810,42 +809,9 @@ struct policy_hub
BLOCK_SCAN_RAKING_MEMOIZE,
RADIX_SORT_STORE_DIRECT,
8>;
using OnesweepPolicy = //
::cuda::std::_If<sizeof(KeyT) < 4, //
OnesweepSmallKeyPolicy, //
OnesweepLargeKeyPolicy>;

using ScanPolicy =
AgentScanPolicy<512,
23,
OffsetT,
BLOCK_LOAD_WARP_TRANSPOSE,
LOAD_DEFAULT,
BLOCK_STORE_WARP_TRANSPOSE,
BLOCK_SCAN_RAKING_MEMOIZE>;

using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
512,
23,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MATCH,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS>;

using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
(sizeof(KeyT) > 1) ? 256 : 128,
47,
DominantT,
BLOCK_LOAD_TRANSPOSE,
LOAD_DEFAULT,
RADIX_RANK_MEMOIZE,
BLOCK_SCAN_WARP_SCANS,
PRIMARY_RADIX_BITS - 1>;

using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>;
public:
using OnesweepPolicy = ::cuda::std::_If<sizeof(KeyT) < 4, OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy>;

using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
256,
Expand Down

0 comments on commit 29ba731

Please sign in to comment.