Adding ROCm support [AMD official] (state-spaces#359)
* Current progress

* Current progress

* FWD kernel compiles

* Current progress: compiles and imports, not tested

* Adapting tests

* working benchmark for torch hipify

* Removing hard-coded device typos

* all dtypes support

* fixed parallel compile issue

* use optimized causal_conv1d if available

* april18 perf bench script

* april18 perf bench script

* Delete csrc/selective_scan/selective_scan_fwd_kernel_minimal.cuh

* reverted benchmark

* reverted changes to base iteration 1

* removed files not in base

* Ported bwd changes (partial)

* Backward working fp32

* all dtypes with bwd

* gitignore hipified files

* rocm cond and move max min to common

* triton autotune conditional

* Unifying setup.py (in progress)

* triton conditional autotune configs

* some more conditional compiles

* Setup.py functional

* Functional

* Minmax changes

* reduce repeatability

* Removed extra comments

* fix template error

* Update csrc/selective_scan/reverse_scan.cuh

Co-authored-by: Jeff Daily <[email protected]>

* restore permissions to base

* permission for gitignore and readme

* warp size based on code review

* Adding ifndef + warnings for dynamic memory size adjustment

* minor changes to setup

* fall back for warp size conditional

* patch method updated

* Minor stylistic changes + an extra warning about patching

* 4096 kNLoads patch

* Cleanup, conditional kernel launch parameters

* Flexible warp size

* Fix warp size to 32 for CUDA

---------

Co-authored-by: Arseny Moskvichev <[email protected]>
Co-authored-by: ajassani <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Adeem Jassani <[email protected]>
Co-authored-by: Gabe Weisz <[email protected]>
Co-authored-by: ajassani <[email protected]>
Co-authored-by: Jeff Daily <[email protected]>
Co-authored-by: root <[email protected]>
9 people authored Jun 18, 2024
1 parent c2568f5 commit 3c77dcf
Showing 11 changed files with 479 additions and 142 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
*.egg-info/
build/
**.so
*.hip
*_hip.*
13 changes: 13 additions & 0 deletions README.md
100644 → 100755
@@ -17,6 +17,19 @@ Mamba is a new state space model architecture showing promising performance on i
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).

## Prerequisites

### Patching ROCm

If you are on ROCm 6.0, perform the following steps to avoid errors during compilation (a quick way to check your installed version is shown below). This is not required for ROCm 6.1 onwards.

1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.

2. Apply the patch. Run with `sudo` if you encounter permission issues.
```bash
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
```
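
If you are not sure which ROCm release you have, either of the following usually reports it (assuming the standard `/opt/rocm` install location; exact paths can vary by distribution):

```bash
# Version recorded by the ROCm installation itself
cat /opt/rocm/.info/version
# Or ask the HIP toolchain
hipconfig --version
```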

## Installation

- [Optional] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
28 changes: 21 additions & 7 deletions csrc/selective_scan/reverse_scan.cuh
@@ -4,12 +4,17 @@

#pragma once

#include <cub/config.cuh>

#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/block/block_raking_layout.cuh>
// #include <cub/detail/uninitialized_copy.cuh>
#ifndef USE_ROCM
#include <cub/config.cuh>

#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>
#include <cub/block/block_raking_layout.cuh>
// #include <cub/detail/uninitialized_copy.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
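// With the alias above, the unqualified cub:: names used in the rest of this
// header resolve to their hipcub equivalents when compiling for ROCm.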
#include "uninitialized_copy.cuh"

/**
@@ -46,6 +51,7 @@ __device__ __forceinline__ T ThreadReverseScanInclusive(
inclusive = scan_op(inclusive, input[i]);
output[i] = inclusive;
}
return inclusive;
}

/**
@@ -89,7 +95,15 @@ struct WarpReverseScan {
//---------------------------------------------------------------------

/// Whether the logical warp size and the PTX warp size coincide
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0));

// In hipcub, warp_threads is defined via HIPCUB_WARP_THREADS, which expands to ::rocprim::warp_size(),
// while in cub it's defined as a macro that takes a redundant unused argument.
#ifndef USE_ROCM
#define WARP_THREADS CUB_WARP_THREADS(0)
#else
#define WARP_THREADS HIPCUB_WARP_THREADS
#endif
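// WARP_THREADS resolves to 32 on CUDA; under ROCm it comes from rocprim and
// matches the target's wavefront width (64 on GCN/CDNA GPUs, 32 on RDNA).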
static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS);
/// The number of warp scan steps
static constexpr int STEPS = cub::Log2<LOGICAL_WARP_THREADS>::VALUE;
static_assert(LOGICAL_WARP_THREADS == 1 << STEPS);
111 changes: 75 additions & 36 deletions csrc/selective_scan/selective_scan_bwd_kernel.cuh
@@ -9,10 +9,15 @@
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#include <ATen/cuda/Atomic.cuh> // For atomicAdd on complex

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_reduce.cuh>
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_reduce.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "selective_scan.h"
#include "selective_scan_common.h"
@@ -33,7 +38,7 @@ struct Selective_Scan_bwd_kernel_traits {
static constexpr int kNItems = kNItems_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
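// constexpr_min (and custom_max below) are small helpers moved into
// selective_scan_common.h as part of this change, replacing std::min/std::max
// for hipcc compatibility in these compile-time expressions.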
static_assert(kNItems % kNElts == 0);
static constexpr int kNLoads = kNItems / kNElts;
static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
@@ -61,12 +66,13 @@ struct Selective_Scan_bwd_kernel_traits {
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
using BlockReduceComplexT = cub::BlockReduce<complex_t, kNThreads>;
using BlockExchangeT = cub::BlockExchange<float, kNThreads, !kIsComplex ? kNItems : kNItems * 2>;
static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
sizeof(typename BlockLoadVecT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
sizeof(typename BlockStoreT::TempStorage),
sizeof(typename BlockStoreVecT::TempStorage)});

static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
sizeof(typename BlockLoadVecT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
sizeof(typename BlockStoreT::TempStorage),
sizeof(typename BlockStoreVecT::TempStorage)});
static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage);
static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage);
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage);
@@ -263,12 +269,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
// Initialize running total
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f);
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
Ktraits::BlockScanT(smem_scan).InclusiveScan(
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
);
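// The typename keyword added to these dependent names (Ktraits::BlockScanT and
// friends) is required by hipcc's clang-based front end; nvcc happened to
// accept the code without it.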
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f);
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
);
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
@@ -297,11 +303,11 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
// Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower
if constexpr (kIsVariableB || kIsVariableC) {
if constexpr (kIsVariableB) {
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals);
}
if constexpr (kIsVariableC) {
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals);
}
const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x;
weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x;
@@ -316,13 +322,13 @@
}
if constexpr (!kIsVariableB || !kIsVariableC) {
float2 dA_dBC_val = make_float2(dA_val, dBC_val);
dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
dA_val = dA_dBC_val.x;
if (threadIdx.x == 0) {
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx];
}
} else {
dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
dA_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val);
}
if (threadIdx.x == 0) {
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
@@ -356,12 +362,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
// Initialize running total
scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
Ktraits::BlockScanT(smem_scan).InclusiveScan(
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
);
scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
SSMScanPrefixCallbackOp<weight_t> postfix_op(running_postfix);
Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
typename Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan(
thread_reverse_data, thread_reverse_data, SSMScanOp<weight_t>(), postfix_op
);
if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; }
@@ -397,7 +403,7 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
dB_vals_f[i * 2] = dB_vals[i].real_;
dB_vals_f[i * 2 + 1] = dB_vals[i].imag_;
}
Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
typename Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f);
}
if constexpr (kIsVariableC) {
#pragma unroll
@@ -406,7 +412,7 @@
dC_vals_f[i * 2 + 1] = dC_vals[i].imag_;
}
auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1;
Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
typename Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f);
}
const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x;
float *dB_cur = reinterpret_cast<float *>(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x;
@@ -421,14 +427,14 @@
}
if constexpr (!kIsVariableB || !kIsVariableC) {
float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_);
dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
dA_dBC_val = typename Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val);
dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y);
dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w);
if (threadIdx.x == 0) {
smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx];
}
} else {
dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
dA_val = typename Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val);
}
if (threadIdx.x == 0) {
smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx];
@@ -465,12 +471,12 @@ void selective_scan_bwd_kernel(SSMParamsBwd params) {
Cvar -= kChunkSize * (!kIsComplex ? 1 : 2);
}
if (params.dD_ptr != nullptr) {
dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
dD_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val);
if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); }
}
if (params.ddelta_bias_ptr != nullptr) {
__syncthreads();
ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
ddelta_bias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val);
if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); }
}
for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
@@ -499,13 +505,24 @@ void selective_scan_bwd_launch(SSMParamsBwd &params, cudaStream_t stream) {
// using Ktraits = Selective_Scan_bwd_kernel_traits<kNThreads, kNItems, true, kIsVariableB, kIsVariableC, kDeltaSoftplus, kHasZ, input_t, weight_t>;
// TODO: check this
constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t);
// printf("smem_size = %d\n", kSmemSize);

dim3 grid(params.batch, params.dim);

auto kernel = &selective_scan_bwd_kernel<Ktraits>;

if (kSmemSize >= 48 * 1024) {

#ifndef USE_ROCM
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
std::cerr << "Warning (selective_scan_fwd_kernel): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
#endif
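// Context: on CUDA, a kernel must opt in via cudaFuncSetAttribute before it may
// use more than 48 KB of dynamic shared memory. The (void *) cast in the ROCm
// branch matches hipFuncSetAttribute, which takes the kernel as a plain pointer;
// since the attribute is a no-op there (ROCm <= 6.1), the warning above is emitted.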

}

kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
@@ -517,15 +534,37 @@

template<typename input_t, typename weight_t>
void selective_scan_bwd_cuda(SSMParamsBwd &params, cudaStream_t stream) {
if (params.seqlen <= 128) {
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 1024) {
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
} else {
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);

#ifndef USE_ROCM
constexpr int warp_size = 32;
#else
constexpr int warp_size = rocprim::warp_size();
#endif
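// CUDA warps are always 32 threads wide; AMD wavefronts are 64 wide on
// GCN/CDNA GPUs and 32 wide on RDNA, so the launch configurations below are
// selected per warp size.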

if (warp_size == 32) {
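// In each branch, kNThreads * kNItems is the chunk size (elements handled per
// block per pass), sized so a sequence at the given threshold fits in one chunk.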
if (params.seqlen <= 128) {
selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 256) {
selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 1024) {
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
} else {
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
}
}
#ifdef USE_ROCM
else {
if (params.seqlen <= 256) {
selective_scan_bwd_launch<64, 4, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 512) {
selective_scan_bwd_launch<64, 8, input_t, weight_t>(params, stream);
} else if (params.seqlen <= 1024) {
selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream);
} else {
selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream);
}
}
#endif
}