Skip to content

Commit

Permalink
adding tests workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoyaoding committed Oct 28, 2022
1 parent b020433 commit 4f58872
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 45 deletions.
25 changes: 20 additions & 5 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,41 @@ name: Python Tests

on:
push:
# branches: [main]
branches: [main]
pull_request:

jobs:
build:
tests:
runs-on: [self-hosted, Linux, X64, gpu]
container:
image: nvidia/cuda:11.8.0-base-ubuntu20.04
image: nvidia/cuda:11.8.0-devel-ubuntu20.04
options: --gpus all
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
- name: Setup cmake
uses: jwlawson/[email protected]
with:
cmake-version: '3.19.x'
- name: Setup ccache
run: |
apt update && apt install -y ccache
- name: Build hidet
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run pytest
pip install -r requirements-dev.txt
bash scripts/build_wheel.sh
WHEEL=$(find ./scripts/ -maxdepth 1 -name '*.whl')
echo "Built wheel: $WHEEL"
pip install --no-deps --force-reinstall $WHEEL
- name: Run minimal tests
run: |
python -m pytest -v tests/minimal/test_add.py
- name: Run full tests
run: |
# stop the build if format is not correct
python -m pytest -v ./tests
2 changes: 2 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
recursive-include python/hidet/lib *.so
recursive-include python/hidet/include *.h
1 change: 0 additions & 1 deletion include/hidet/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <cstdlib>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <iostream>
#include <hidet/runtime/common.h>
#include <hidet/runtime/logging.h>
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def compile_source(src_path: str, out_lib_path: str, keep_ptx=False) -> None:
with tempfile.TemporaryDirectory() as working_dir:
result = subprocess.run(" ".join(command).split(), stderr=PIPE, stdout=PIPE, cwd=working_dir, check=False)
if result.returncode:
message = ''
message = "Command: " + " ".join(command) + "\n"
if result.stdout:
message += result.stdout.decode().strip() + '\n'
if result.stderr:
Expand All @@ -117,7 +117,7 @@ def compile_source(src_path: str, out_lib_path: str, keep_ptx=False) -> None:
out_lib_dir = os.path.dirname(out_lib_path)
ptx_path = os.path.join(working_dir, ptx_name)
target_ptx_path = os.path.join(out_lib_dir, ptx_name)
os.rename(ptx_path, target_ptx_path)
shutil.move(ptx_path, target_ptx_path)
raise CompilationFailed(src_path, message)
out_lib_dir = os.path.dirname(out_lib_path)
if keep_ptx:
Expand Down
3 changes: 0 additions & 3 deletions python/hidet/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from . import hidet_models
from . import torch_models
from . import onnx_models

from . import utils

from .utils import benchmark_func, check_unary, check_binary
8 changes: 6 additions & 2 deletions python/hidet/testing/onnx_models/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from hidet.utils import hidet_cache_file
from hidet.utils.transformers_utils import export_transformer_model_as_onnx
from hidet.utils.torch_utils import export_torchvision_model_as_onnx
from .model_blocks import get_bert_block, get_resnet50_block
from .operators import get_onnx_operator


def get_onnx_model(
Expand Down Expand Up @@ -77,10 +75,16 @@ def get_onnx_model(
]
return model_path, input_names, input_tensors
elif name.startswith('resnet50_'):
from .model_blocks import get_resnet50_block

return get_resnet50_block(name, batch_size=batch_size, precision=precision, **kwargs)
elif name.startswith('bert_'):
from .model_blocks import get_bert_block

return get_bert_block(name, batch_size=batch_size, precision=precision, **kwargs)
elif name.startswith('op_'):
from .operators import get_onnx_operator

return get_onnx_operator(name, batch_size, precision=precision)
else:
raise NotImplementedError('Can not recognize model {}'.format(name))
42 changes: 29 additions & 13 deletions python/hidet/testing/onnx_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,38 @@
import onnx
import hidet

try:
import torch
from torch import nn
except ImportError:
pass


def export_torch_to_onnx(
onnx_path: str,
model: nn.Module,
input_names: List[str],
inputs: List[torch.Tensor],
precision: Optional[str] = None,
nocache=False,
onnx_path: str, model, input_names: List[str], inputs, precision: Optional[str] = None, nocache=False
):
# onnx_path = hidet_cache_file('onnx', 'bert', f'{name}.onnx')
"""
Export a torch model to onnx.
Parameters
----------
onnx_path: str
Path to store the onnx file.
model: torch.nn.Module
The torch model to be exported.
input_names: List[str]
The names of the inputs in the exported onnx model.
inputs: Sequence[torch.Tensor]
The inputs to the model.
precision: Optional[str]
The precision of the exported onnx model. If None, the precision of the model is not changed.
Candidates: 'float16', 'float32'
nocache: bool
If True, the onnx model will be exported even if the onnx file already exists.
Returns
-------
(onnx_path, input_names, hidet_inputs): Tuple[str, List[str], List[hidet.Tensor]]
The path to the exported onnx model, the names of the inputs in the exported onnx model, and the inputs to the
exported onnx model.
"""

import torch

if nocache and os.path.exists(onnx_path):
os.remove(onnx_path)
precision_dict = {'float32': torch.float32, 'float16': torch.float16}
Expand Down
3 changes: 2 additions & 1 deletion python/hidet/utils/transformers_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import os
import subprocess
import shutil


def export_transformer_model_as_onnx(
Expand Down Expand Up @@ -32,7 +33,7 @@ def export_transformer_model_as_onnx(
command = '{} -m transformers.onnx --model {} --feature {} {}'.format(sys.executable, model_name, feature, temp_dir)
print("Running '{}'".format(command))
subprocess.run(command.split(), check=True)
os.rename(os.path.join(temp_dir, 'model.onnx'), output_path)
shutil.move(os.path.join(temp_dir, 'model.onnx'), output_path)
print('Model saved at: {}'.format(output_path))


Expand Down
9 changes: 9 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# python test
pytest

# for models to test
torchvision
transformers

# check the correctness with onnxruntime
onnxruntime-gpu
15 changes: 9 additions & 6 deletions scripts/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ rm -rf build; mkdir build;
cd build; cmake ../..; make -j4; cd ..

# copy the built libraries and headers to python module
cp -r ./build/lib ../python/hidet
cp -r ../include ../python/hidet
cp ../setup.py ./setup.py
cp ../MANIFEST.in ./MANIFEST.in
cp -r ../python ./
cp -r ./build/lib ./python/hidet
cp -r ../include ./python/hidet

# build wheel
pip wheel --no-deps ..
pip wheel --no-deps .

# remove all intermediate directories
rm -rf ../python/hidet/hidet.egg-info
rm -rf ../python/hidet/lib
rm -rf ../python/hidet/include
rm -rf ./python
rm -rf ./build
rm ./setup.py
rm ./MANIFEST.in
28 changes: 19 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from setuptools import setup, find_packages
from glob import glob
from setuptools import setup, find_packages, Distribution


class BinaryDistribution(Distribution):
def has_ext_modules(self):
return True

def is_pure(self):
return False


setup(
name="hidet",
Expand All @@ -8,13 +18,12 @@
packages=find_packages(where='python'),
package_dir={"": "python"},
include_package_data=True,
package_data={
'hidet': [
'lib/*.so',
'include/**/*.h'
]
},
zip_safe=False,
# package_data={
# 'hidet': [
# *glob('lib/*.so'),
# *glob('include/**/*.h', recursive=True),
# ]
# },
install_requires=[
"onnx",
"numpy",
Expand All @@ -23,5 +32,6 @@
"nvtx",
"tabulate",
"astunparse"
]
],
distclass=BinaryDistribution,
)
17 changes: 14 additions & 3 deletions tests/graph/operators/test_matmul.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import numpy as np
import pytest

import hidet
from hidet import ops
from hidet.testing import check_binary


@pytest.mark.parametrize(
"a_shape, b_shape, dtype", [[[1, 128, 128], [1, 128, 128], "float32"], [[1, 333, 444], [1, 444, 555], "float32"]]
"a_shape, b_shape, dtype", [[[1, 333, 444], [1, 444, 555], "float32"], [[1, 333, 444], [1, 444, 555], "float16"]]
)
@pytest.mark.parametrize('mma', ['simt', 'wmma', 'mma'])
def test_batch_matmul(a_shape, b_shape, dtype, mma):
mma2tolerance = {'simt': 1e-4, 'wmma': 0.05, 'mma': 0.05}
tol = mma2tolerance[mma]
if hidet.utils.cuda.query_compute_capability() < (8, 0) and mma in ['wmma', 'mma'] and dtype == 'float32':
pytest.skip('wmma and mma for float32 will triger hidet to use tf32, which is only supported on sm80 and above')
tolerance = {
('float16', 'simt'): 0.5,
('float16', 'wmma'): 0.5,
('float16', 'mma'): 0.5,
('float32', 'simt'): 1e-4,
('float32', 'wmma'): 0.05,
('float32', 'mma'): 0.05,
}
tol = tolerance[(dtype, mma)]
check_binary(
a_shape,
b_shape,
lambda x, y: np.matmul(x, y),
lambda x, y: ops.batch_matmul(x, y, mma=mma),
device='cuda',
dtype=dtype,
atol=tol,
rtol=tol,
Expand Down
7 changes: 7 additions & 0 deletions tests/ir/primitives/cuda/test_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def matmul_mma_tensor_core(config: MmaConfig):
],
)
def test_mma(config: MmaConfig):
if hidet.utils.cuda.query_compute_capability() < (8, 0):
if 'tf32' in [config.input_dtype, config.output_dtype]:
pytest.skip('tfloat32 tensor core is supported on device with sm80 or higher')
if 'bf16' in [config.input_dtype, config.output_dtype]:
pytest.skip('bfloat16 tensor core is supported on device with sm80 or higher')
if (config.m, config.n, config.k) in [(16, 8, 16)]:
pytest.skip('tensor core with shape m16n8k16 is supported on device with sm80 or higher')
ir_module = matmul_mma_tensor_core(config)
func = build_ir_module(
ir_module,
Expand Down
15 changes: 15 additions & 0 deletions tests/minimal/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
import numpy as np
import hidet


def test_add():
a = hidet.randn([10], device='cuda')
b = hidet.randn([10], device='cuda')
c = a + b
c_np = a.numpy() + b.numpy()
np.testing.assert_allclose(actual=c.numpy(), desired=c_np, atol=1e-5, rtol=1e-5)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 4f58872

Please sign in to comment.