[Feature Request] Parallel Primitive Should be enhanced to improve the performance for irregular shapes #209

Open
LeiWang1999 opened this issue Oct 2, 2024 · 0 comments

To reproduce a failing case:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
from tvm import tl
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
    MatmulFineGrainScheduler,
)

import torch
import torch.backends

torch.manual_seed(0)


def assert_matmul_fine_grained_apply_config_correctness(
    M,
    N,
    K,
    trans_A=False,
    trans_B=True,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    block_row_warps=1,
    block_col_warps=1,
    warp_row_tiles=16,
    warp_col_tiles=16,
    chunk=32,
    num_stages=2,
    enable_rasterization=False,
):

    matmul = MatmulFineGrainScheduler(
        M=M,
        N=N,
        K=K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
    ).apply_config(
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        num_stages=num_stages,
        enable_rasterization=enable_rasterization,
    )

    mod, params = tl.lower(matmul)
    src_code = mod.imported_modules[0].get_source()

    # src_code is the generated cuda source
    assert src_code is not None

    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))

    mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)

    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
    torch.testing.assert_close(C, ref_c, rtol=1e-1, atol=1e-1)

assert_matmul_fine_grained_apply_config_correctness(
    768, 768, 768, False, True, "float16", "float16", "float16",
    block_row_warps=1, block_col_warps=3, warp_row_tiles=16, warp_col_tiles=16, chunk=128, num_stages=0,
)
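
For context: with block_row_warps=1 and block_col_warps=3 above, and assuming the usual 32-thread warp size, the launched thread block has 1 × 3 × 32 = 96 threads. That number matters for the failure below.

# Thread-block size implied by the config above (assumes 32 threads/warp).
block_row_warps, block_col_warps, warp_size = 1, 3, 32
num_threads = block_row_warps * block_col_warps * warp_size  # 96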

The output is:

root@e01d939002c0:~/BitBLAS# /usr/bin/python /root/BitBLAS/debug/test_issue_parallel.py
Traceback (most recent call last):
  File "/root/BitBLAS/debug/test_issue_parallel.py", line 138, in <module>
    assert_matmul_fine_grained_apply_config_correctness(
  File "/root/BitBLAS/debug/test_issue_parallel.py", line 56, in assert_matmul_fine_grained_apply_config_correctness
    mod, params = tl.lower(matmul)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/tl/engine.py", line 84, in lower
    mod = tl.transform.LayoutInference()(mod)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/root/BitBLAS/bitblas/../3rdparty/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  56: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  55: tvm::transform::Pass::operator()(tvm::IRModule) const
  54: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  53: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  52: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS_2tl15LayoutInferenceEvEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  51: tvm::tl::LayoutInferencer::Substitute(tvm::tir::PrimFunc)
  50: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  49: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  48: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  47: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
  46: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  45: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  44: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  43: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  42: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  41: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  40: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  39: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  38: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  37: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  36: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  35: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  34: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  33: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  32: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  31: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  30: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  29: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  28: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  27: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  26: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  25: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  24: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  23: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  22: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  21: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::BlockNode const*)
  20: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::BlockNode const*)
  19: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  18: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  17: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  16: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  15: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  14: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  13: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  12: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  11: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  10: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  9: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  8: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9Ob
  6: tvm::runtime::ObjectPtr<tvm::runtime::Object> tvm::runtime::Array<tvm::tir::Stmt, void>::MapHelper<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}, tvm::tir::Stmt>(tvm::runtime::Object, tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})
  5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  3: tvm::tl::LayoutInferencer::VisitStmt_(tvm::tir::ForNode const*)
  2: tvm::tl::PartitionLoop(tvm::tir::For, tvm::tir::Var, tvm::arith::Analyzer*, tvm::tl::Fragment)
  1: tvm::tl::FragmentNode::Inverse() const
  0: tvm::tl::LayoutNode::Inverse() const
  File "/root/BitBLAS/3rdparty/tvm/src/tl/layout/layout.cc", line 205
InternalError: Check failed: (res->errors.empty()) is false: ["The iterations do not traverse full iter space", "Index mapping does not form a bijective transform."]

The problem lies in how the parallel primitive is lowered:

for i in T.parallel(16):
    for j in T.parallel(128):
        A_shared[i, j] = A[x, y]

The thread block has only 96 threads, and the fused extent 16 × 128 = 2048 is not divisible by 96, so layout inference cannot construct a bijective mapping from threads to iterations and fails the check above.

However, this tile is the most efficient among the tile configurations, and it can be tensorized using our raw BitBLAS TIR backend.
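
One possible enhancement is to let the parallel lowering predicate the ragged tail instead of requiring a bijective thread-to-iteration map. A minimal pure-Python sketch of that partitioning arithmetic (hypothetical; not the current tl implementation):

# Hypothetical tail-predicated partitioning: fuse the 16x128 iteration
# space and split it across 96 threads with a bounds guard, rather than
# requiring the extent to divide evenly.
def partition(extent_i=16, extent_j=128, num_threads=96):
    total = extent_i * extent_j            # 2048 elements
    per_thread = -(-total // num_threads)  # ceildiv -> 22
    for tx in range(num_threads):
        for v in range(per_thread):
            flat = tx * per_thread + v
            if flat >= total:              # guard the ragged tail
                continue
            i, j = divmod(flat, extent_j)
            yield tx, i, j                 # thread tx handles element (i, j)

# Every (i, j) is covered exactly once even though 2048 % 96 != 0.
assert sorted((i, j) for _, i, j in partition()) == [
    (i, j) for i in range(16) for j in range(128)
]

Such a scheme would let the 16×128 copy run on 96 threads at the cost of one tail predicate per thread.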
