Skip to content

Commit

Permalink
Fixed tail loop for swizzle-A for arbitrary M & K
Browse files Browse the repository at this point in the history
  • Loading branch information
Serge45 committed Jan 7, 2025
1 parent 7a1e518 commit 3f3ebf6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 12 deletions.
5 changes: 0 additions & 5 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2639,11 +2639,6 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
globalReadMode2nd = 2 if (((tensorParameters2nd["glvw"] * tensorParameters2nd["bpeGR"]) < 4) or \
kernel["tailLoopOpt"] == False) else 0

# if we have swizzled A or B, then size-K is already guarded, so we don't have to use guarded-k GR again
hasSwizzled = tensorParametersA["isSwizzled"] or tensorParametersB["isSwizzled"]
globalReadMode1st = 0 if hasSwizzled else globalReadMode1st
globalReadMode2nd = 0 if hasSwizzled else globalReadMode2nd

module.addComment1("Update M0 for DTLDS")
moduleTmp = self.directToLdsM0Update(kernel, 1, tensorParameters1st)
module.add(replaceHolder(moduleTmp, 0))
Expand Down
12 changes: 5 additions & 7 deletions tensilelite/Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -5366,13 +5366,11 @@ def mfmaIter(self, kernel, tPA, tPB, u, innerUnroll, vregSetIdx, unrollLoopIdx =
if kernel["LocalSplitU"] > 1:
shiftK.add(SMinI32(dst=sgpr(loopCntSgpr), src0=sgpr(loopCounterName), src1=sgpr("LSUTailLoopOffset"), comment="check lsu bound"))
shiftK.add(VCmpGEI32(dst=sgpr(tmpSgprX2, self.states.laneSGPRCount), src0=vgpr(kReg), src1=sgpr(loopCntSgpr), comment="check K index >= Size L"))
# if A is swizzled, then no need to do this since MatB will be set to 0
if not tPA["isSwizzled"]:
for bk in range(0, vgprPerSet0Group):
for a in range(0, kernel["MIWaveTileA"]):
for iui in range(0, innerUnroll):
aStr = vgpr(self.generateSrcStrForMFMA(kernel, tPA, innerUnroll, vregSetIdx, vgprPerInputA, m, u, iui, a, bk=bk + group * vgprPerSet0Group), 1)
shiftK.add(VCndMaskB32(dst=aStr, src0=aStr, src1=hex(0), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL"))
for bk in range(0, vgprPerSet0Group):
for a in range(0, kernel["MIWaveTileA"]):
for iui in range(0, innerUnroll):
aStr = vgpr(self.generateSrcStrForMFMA(kernel, tPA, innerUnroll, vregSetIdx, vgprPerInputA, m, u, iui, a, bk=bk + group * vgprPerSet0Group), 1)
shiftK.add(VCndMaskB32(dst=aStr, src0=aStr, src1=hex(0), src2=sgpr(tmpSgprX2, self.states.laneSGPRCount), comment="set 0 if K_idx >= sizeL"))

if kernel["ProblemType"]["Sparse"] == 2 and numMIInput//8 >= 1:
shiftK.add(vectorStaticRemainder(dummy, kReg, "Serial", kernel["WavefrontSize"], tmpVgpr, tmpSgprInfo))
Expand Down

0 comments on commit 3f3ebf6

Please sign in to comment.