-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptmized_ttir.txt
104 lines (104 loc) · 6.31 KB
/
optmized_ttir.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// Triton TTIR for a tiled f16 matrix-multiply kernel with f32 accumulation.
// Each program instance produces one 128x64 output tile, iterating over the
// reduction dimension in steps of 32.
module {
// Arguments, as used by the body below:
//   %arg0 : f16 pointer; 128x32 tiles are loaded from it (presumably matrix A — confirm against the Python source).
//   %arg1 : f16 pointer; 32x64 tiles are loaded from it (presumably matrix B — confirm).
//   %arg2 : f16 pointer; the 128x64 result tile is stored through it (presumably matrix C — confirm).
//   %arg3 : i32 bound on the row index (rows are wrapped modulo it for loads and masked against it for the store).
//   %arg4 : i32 bound on the column index (same treatment as %arg3, for columns).
//   %arg5 : i32 extent of the reduction dimension (drives the loop trip count and the load masks).
//   %arg6 : i32 multiplier applied to the row index when addressing %arg0 (row stride, in elements).
//   %arg7 : i32 multiplier applied to the reduction index when addressing %arg1 (row stride, in elements).
//   %arg8 : i32 multiplier applied to the row index when addressing %arg2 (row stride, in elements).
// No column stride appears for any of the three pointers: column offsets are added unscaled,
// i.e. the innermost stride is effectively 1 in this specialization.
tt.func public @matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// ---- Constants --------------------------------------------------------------
// 31 / 63 / 127 are the "tile size - 1" addends used for the rounding-up divisions below.
%c31_i32 = arith.constant 31 : i32
%c63_i32 = arith.constant 63 : i32
%c127_i32 = arith.constant 127 : i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// Per-iteration pointer increment for the %arg0 tile: 32 elements along the reduction axis.
%cst = arith.constant dense<32> : tensor<128x32xi32>
// Zero-initialised f32 accumulator for the 128x64 output tile.
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf32>
// Zero tensors; truncated to f16 further down and passed as the third operand of the masked loads.
%cst_1 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32>
// Tile sizes: 32 along the reduction axis, 64 columns, 128 rows; 8 is the program-grouping factor.
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
// ---- Map the linear program id to a (row-tile, column-tile) pair ------------
%0 = tt.get_program_id {axis = 0 : i32} : i32
// %2 = (%arg3 + 127) / 128: number of 128-row tiles (rounds up for non-negative %arg3).
%1 = arith.addi %arg3, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
// %4 = (%arg4 + 63) / 64: number of 64-column tiles (rounds up for non-negative %arg4).
%3 = arith.addi %arg4, %c63_i32 : i32
%4 = arith.divsi %3, %c64_i32 : i32
// %5 = programs per group: 8 row-tiles times all column-tiles.
%5 = arith.muli %4, %c8_i32 : i32
// %6 = group index of this program; %7 = first row-tile index belonging to that group.
%6 = arith.divsi %0, %5 : i32
%7 = arith.muli %6, %c8_i32 : i32
// %10 = min(%2 - %7, 8): number of row-tiles in this group (the last group may hold fewer than 8).
%8 = arith.subi %2, %7 : i32
%9 = arith.cmpi slt, %8, %c8_i32 : i32
%10 = arith.select %9, %8, %c8_i32 : i32
// %12 = row-tile index: group base plus (program id mod group height).
%11 = arith.remsi %0, %10 : i32
%12 = arith.addi %7, %11 : i32
// %14 = column-tile index: (program id mod programs-per-group) / group height.
%13 = arith.remsi %0, %5 : i32
%14 = arith.divsi %13, %10 : i32
// ---- Row / column / reduction index vectors ---------------------------------
// %18 = absolute row indices of this tile: %12 * 128 + [0, 128).
%15 = arith.muli %12, %c128_i32 : i32
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%17 = tt.splat %15 : (i32) -> tensor<128xi32>
%18 = arith.addi %17, %16 : tensor<128xi32>
// %20 = row indices wrapped modulo %arg3, so that load addresses derived from them stay below the row bound.
%19 = tt.splat %arg3 : (i32) -> tensor<128xi32>
%20 = arith.remsi %18, %19 : tensor<128xi32>
// %24 = absolute column indices of this tile: %14 * 64 + [0, 64).
%21 = arith.muli %14, %c64_i32 : i32
%22 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
%23 = tt.splat %21 : (i32) -> tensor<64xi32>
%24 = arith.addi %23, %22 : tensor<64xi32>
// %26 = column indices wrapped modulo %arg4 (same purpose as %20, for columns).
%25 = tt.splat %arg4 : (i32) -> tensor<64xi32>
%26 = arith.remsi %24, %25 : tensor<64xi32>
// %27 = [0, 32): offsets along the reduction axis within one step.
%27 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
// ---- Initial pointers for the 128x32 tile read from %arg0 -------------------
// Offset(i, k) = %20[i] * %arg6 + %27[k].
%28 = tt.expand_dims %20 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
%29 = tt.splat %arg6 : (i32) -> tensor<128x1xi32>
%30 = arith.muli %28, %29 : tensor<128x1xi32>
// %31 (1x32 reduction offsets) is reused inside the loop to build the load mask.
%31 = tt.expand_dims %27 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
%32 = tt.broadcast %30 : (tensor<128x1xi32>) -> tensor<128x32xi32>
%33 = tt.broadcast %31 : (tensor<1x32xi32>) -> tensor<128x32xi32>
%34 = arith.addi %32, %33 : tensor<128x32xi32>
%35 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>>
%36 = tt.addptr %35, %34 : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
// ---- Initial pointers for the 32x64 tile read from %arg1 --------------------
// Offset(k, j) = %27[k] * %arg7 + %26[j].
// %37 (32x1 reduction offsets) is reused inside the loop to build the load mask.
%37 = tt.expand_dims %27 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
%38 = tt.splat %arg7 : (i32) -> tensor<32x1xi32>
%39 = arith.muli %37, %38 : tensor<32x1xi32>
%40 = tt.expand_dims %26 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%41 = tt.broadcast %39 : (tensor<32x1xi32>) -> tensor<32x64xi32>
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<32x64xi32>
%43 = arith.addi %41, %42 : tensor<32x64xi32>
%44 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<32x64x!tt.ptr<f16>>
%45 = tt.addptr %44, %43 : tensor<32x64x!tt.ptr<f16>>, tensor<32x64xi32>
// ---- Loop-invariant values --------------------------------------------------
// %47 = (%arg5 + 31) / 32: number of reduction steps (rounds up for non-negative %arg5).
%46 = arith.addi %arg5, %c31_i32 : i32
%47 = arith.divsi %46, %c32_i32 : i32
// f16 zero tiles handed to the masked loads as their third operand.
%48 = arith.truncf %cst_2 : tensor<128x32xf32> to tensor<128x32xf16>
%49 = arith.truncf %cst_1 : tensor<32x64xf32> to tensor<32x64xf16>
// %51 = per-iteration pointer increment for the %arg1 tile: 32 rows, i.e. 32 * %arg7 elements.
%50 = arith.muli %arg7, %c32_i32 : i32
%51 = tt.splat %50 : (i32) -> tensor<32x64xi32>
// ---- Reduction loop ---------------------------------------------------------
// Loop-carried values: %arg10 = f32 accumulator, %arg11 = %arg0 tile pointers, %arg12 = %arg1 tile pointers.
%52:3 = scf.for %arg9 = %c0_i32 to %47 step %c1_i32 iter_args(%arg10 = %cst_0, %arg11 = %36, %arg12 = %45) -> (tensor<128x64xf32>, tensor<128x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>) : i32 {
// %71 = %arg5 - %arg9 * 32: reduction elements remaining at the start of this step.
%70 = arith.muli %arg9, %c32_i32 : i32
%71 = arith.subi %arg5, %70 : i32
// Mask for the %arg0 tile: reduction offset k is enabled only when k < remaining; broadcast over all 128 rows.
%72 = tt.splat %71 : (i32) -> tensor<1x32xi32>
%73 = arith.cmpi slt, %31, %72 : tensor<1x32xi32>
%74 = tt.broadcast %73 : (tensor<1x32xi1>) -> tensor<128x32xi1>
%75 = tt.load %arg11, %74, %48 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16>
// Mask for the %arg1 tile: same k < remaining test, broadcast over all 64 columns.
%76 = tt.splat %71 : (i32) -> tensor<32x1xi32>
%77 = arith.cmpi slt, %37, %76 : tensor<32x1xi32>
%78 = tt.broadcast %77 : (tensor<32x1xi1>) -> tensor<32x64xi1>
%79 = tt.load %arg12, %78, %49 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
// Multiply the two f16 tiles and add into the running f32 accumulator.
%80 = tt.dot %75, %79, %arg10 {allowTF32 = true} : tensor<128x32xf16> * tensor<32x64xf16> -> tensor<128x64xf32>
// Advance both pointer tiles by one reduction step (32 elements for %arg0, 32 * %arg7 for %arg1).
%81 = tt.addptr %arg11, %cst : tensor<128x32x!tt.ptr<f16>>, tensor<128x32xi32>
%82 = tt.addptr %arg12, %51 : tensor<32x64x!tt.ptr<f16>>, tensor<32x64xi32>
scf.yield %80, %81, %82 : tensor<128x64xf32>, tensor<128x32x!tt.ptr<f16>>, tensor<32x64x!tt.ptr<f16>>
}
// ---- Epilogue: convert and store the output tile ----------------------------
// Narrow the f32 accumulator to f16 for the store.
%53 = arith.truncf %52#0 : tensor<128x64xf32> to tensor<128x64xf16>
// Store addresses use the unwrapped indices %18 / %24 (not the modulo versions %20 / %26);
// out-of-range positions are suppressed by the mask %69 instead.
// Offset(i, j) = %arg8 * %18[i] + %24[j].
%54 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32>
%55 = tt.splat %arg8 : (i32) -> tensor<128x1xi32>
%56 = arith.muli %55, %54 : tensor<128x1xi32>
%57 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<128x1x!tt.ptr<f16>>
%58 = tt.addptr %57, %56 : tensor<128x1x!tt.ptr<f16>>, tensor<128x1xi32>
%59 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<64xi32>) -> tensor<1x64xi32>
%60 = tt.broadcast %58 : (tensor<128x1x!tt.ptr<f16>>) -> tensor<128x64x!tt.ptr<f16>>
%61 = tt.broadcast %59 : (tensor<1x64xi32>) -> tensor<128x64xi32>
%62 = tt.addptr %60, %61 : tensor<128x64x!tt.ptr<f16>>, tensor<128x64xi32>
// %69 = (row < %arg3) AND (column < %arg4): only in-bounds elements are written.
%63 = tt.splat %arg3 : (i32) -> tensor<128x1xi32>
%64 = arith.cmpi slt, %54, %63 : tensor<128x1xi32>
%65 = tt.splat %arg4 : (i32) -> tensor<1x64xi32>
%66 = arith.cmpi slt, %59, %65 : tensor<1x64xi32>
%67 = tt.broadcast %64 : (tensor<128x1xi1>) -> tensor<128x64xi1>
%68 = tt.broadcast %66 : (tensor<1x64xi1>) -> tensor<128x64xi1>
%69 = arith.andi %67, %68 : tensor<128x64xi1>
tt.store %62, %53, %69 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16>
tt.return
}
}