[Operators] Support matmul with NT layout (#496)
Closes #474
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent bc5b54e commit 8fc6de3
Showing 9 changed files with 311 additions and 91 deletions.
12 changes: 10 additions & 2 deletions python/hidet/graph/frontend/torch/register_functions.py
@@ -16,6 +16,7 @@
import operator
import functools
import torch

from hidet.graph.tensor import Tensor, from_torch, ones_like, randn
from hidet.graph import ops
from hidet.utils import same_list
@@ -159,9 +160,16 @@ def max_pool3d(x: Tensor, kernel_size, stride, padding=0, dilation=1, ceil_mode=

@register_function(torch.nn.functional.linear)
def linear(x: Tensor, weight: Tensor, bias: Optional[Tensor], weight_is_transposed=False):
+    from hidet import float16
+
    if len(weight.shape) > 1 and not weight_is_transposed:
-        weight = ops.transpose(weight, [1, 0])
-    y = ops.matmul(x, weight)
+        if weight.dtype == float16:
+            y = ops.matmul_nt(x, weight)
+        else:
+            weight = ops.transpose(weight, [1, 0])
+            y = ops.matmul(x, weight)
+    else:
+        y = ops.matmul(x, weight)
    if bias is not None:
        y = y + bias
    return y
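
For readers following the frontend change above: with this patch, torch.nn.functional.linear dispatches an fp16 weight that is not pre-transposed to the new ops.matmul_nt, keeping the weight in its torch-style [out_features, in_features] layout, while other dtypes keep the old transpose-then-matmul path. A minimal sketch of the two paths, assuming a CUDA build of hidet; the shapes and hidet.randn calls below are illustrative, not part of the commit:

    import hidet
    from hidet.graph import ops

    x = hidet.randn([8, 512], dtype='float16', device='cuda')     # activations [batch, in_features]
    w = hidet.randn([1024, 512], dtype='float16', device='cuda')  # torch-style weight [out_features, in_features]

    y_nt = ops.matmul_nt(x, w)                       # new fp16 path: weight consumed in NT layout, no transpose
    y_ref = ops.matmul(x, ops.transpose(w, [1, 0]))  # previous path: materialize the [in, out] transpose first

Both results have shape [8, 1024]; the point of the NT path is to skip the explicit transpose.
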
22 changes: 16 additions & 6 deletions python/hidet/graph/frontend/torch/register_modules.py
@@ -191,17 +191,27 @@ def __call__(self, x: Tensor) -> Tensor:
class HidetLinear(HidetModule):
    def __init__(self, torch_module: torch.nn.Module):
        super().__init__(torch_module)
+        self.can_use_nt_matmul = torch_module.weight.dtype == torch.float16
        steal = dynamo_config['steal_weights']
-        self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])
-        self.torch_params['weight'] = None
-        self.hidet_params['weight'] = None
+        if not self.can_use_nt_matmul:
+            self.transposed_weight = ops.transpose(self.param('weight', steal=steal), [1, 0])
+            self.torch_params['weight'] = None
+            self.hidet_params['weight'] = None
+        else:
+            self.transposed_weight = None
        torch.cuda.empty_cache()

    def __call__(self, x: Tensor) -> Tensor:
        assert isinstance(self.mod, torch.nn.Linear)
-        return reg_funcs.linear(
-            x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True
-        )
+        if self.can_use_nt_matmul:
+            return reg_funcs.linear(
+                x=x, weight=self.param('weight'), bias=self.param('bias', optional=True), weight_is_transposed=False
+            )
+        else:
+            assert self.transposed_weight is not None
+            return reg_funcs.linear(
+                x=x, weight=self.transposed_weight, bias=self.param('bias', optional=True), weight_is_transposed=True
+            )


@register_module(torch.nn.BatchNorm2d)
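
The module-level change mirrors the functional one: for an fp16 torch.nn.Linear, HidetLinear no longer steals and transposes the weight at conversion time (self.transposed_weight stays None), so no [in_features, out_features] copy is materialized and ops.matmul_nt reads the original [out_features, in_features] parameter; other dtypes keep the pre-transposed weight. A hedged end-to-end sketch that would exercise this path through the dynamo frontend, assuming hidet is installed and exposes a torch.compile backend named 'hidet':

    import torch
    import hidet  # assumption: importing hidet makes the 'hidet' dynamo backend available

    model = torch.nn.Linear(512, 1024).cuda().half()   # fp16 weight -> NT path in HidetLinear
    compiled = torch.compile(model, backend='hidet')
    y = compiled(torch.randn(8, 512, device='cuda', dtype=torch.float16))
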
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-builtin
-from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas
+from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas, matmul_nt
from .conv1d import conv1d, conv1d_gemm
from .conv1d_transpose import conv1d_transpose
from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/matmul/__init__.py
@@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .matmul import matmul, MatmulOp, MatmulTask
+from .matmul import matmul, MatmulOp, MatmulTask, matmul_nt
from .batch_matmul import batch_matmul, BatchMatmulOp, BatchMatmulTask
from .matmul_cublas import matmul_cublas
from . import resolve
17 changes: 12 additions & 5 deletions python/hidet/graph/ops/matmul/matmul.py
@@ -14,18 +14,25 @@


class MatmulTask(Task):
-    def __init__(self, a: TensorInput, b: TensorInput):
+    def __init__(self, a: TensorInput, b: TensorInput, transpose_b: bool = False):
        from hidet.ir.compute import cops

-        c = cops.matmul(a, b, allow_1d=True)
+        c = cops.matmul(a, b, allow_1d=True, ta=False, tb=transpose_b)
        super().__init__(name='matmul', inputs=[a, b], outputs=[c])


class MatmulOp(Operator):
-    def __init__(self, a: Tensor, b: Tensor, require_prologue=False):
-        task = MatmulTask(input_like(a, 'a'), input_like(b, 'b'))
-        super().__init__(inputs=[a, b], attributes={'require_prologue': require_prologue}, task=task)
+    def __init__(self, a: Tensor, b: Tensor, require_prologue=False, transpose_b: bool = False):
+        assert a.dtype == b.dtype, f"expected mat1 and mat2 to have the same dtype, but got {a.dtype} != {b.dtype}"
+        task = MatmulTask(input_like(a, 'a'), input_like(b, 'b'), transpose_b=transpose_b)
+        super().__init__(
+            inputs=[a, b], attributes={'require_prologue': require_prologue, 'transpose_b': transpose_b}, task=task
+        )


def matmul(a: Tensor, b: Tensor, require_prologue=False) -> Tensor:
    return MatmulOp(a, b, require_prologue=require_prologue).outputs[0]
+
+
+def matmul_nt(a: Tensor, b: Tensor, require_prologue=False) -> Tensor:
+    return MatmulOp(a, b, require_prologue=require_prologue, transpose_b=True).outputs[0]
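
Semantically, matmul_nt(a, b) is matmul with the second operand supplied in NT layout: b has shape [N, K] rather than [K, N], and the task computes C[i, j] = sum over k of A[i, k] * B[j, k] (the tb=transpose_b flag passed to cops.matmul). A small equivalence check, as a sketch assuming a CUDA build of hidet; the tensor shapes and tolerances below are illustrative:

    import numpy as np
    import hidet
    from hidet.graph import ops

    a = hidet.randn([4, 16], dtype='float16', device='cuda')   # [M, K]
    b = hidet.randn([8, 16], dtype='float16', device='cuda')   # NT layout: [N, K]

    c_nt = ops.matmul_nt(a, b)                       # [M, N], no explicit transpose of b
    c_ref = ops.matmul(a, ops.transpose(b, [1, 0]))  # reference: transpose b to [K, N] first
    np.testing.assert_allclose(c_nt.cpu().numpy(), c_ref.cpu().numpy(), rtol=1e-2, atol=1e-2)
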
