Skip to content

Commit

Permalink
Fixed segfault when using GuardPageBack for swizzle-A
Browse files Browse the repository at this point in the history
  • Loading branch information
Serge45 committed Jan 15, 2025
1 parent c5c4332 commit 63114c7
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions tensilelite/Tensile/Source/client/source/DataInitialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,13 +518,15 @@ namespace TensileLite
void* dst,
void* src,
size_t totalElements,
hipMemcpyKind kind)
hipMemcpyKind kind,
ptrdiff_t customPadding = -1)
{
ptrdiff_t dPadding = totalElements - descriptor.totalAllocatedElements();
const ptrdiff_t dPadding = (customPadding == -1) ? totalElements - descriptor.totalAllocatedElements() : customPadding;
const size_t numElementsToCopy = (customPadding == -1) ? descriptor.totalAllocatedElements() : (descriptor.totalAllocatedElements() + customPadding);
uint8_t* dstOffset = (uint8_t*)dst + (dPadding * descriptor.elementBytes());
HIP_CHECK_EXC(hipMemcpy(dstOffset,
src,
descriptor.elementBytes() * descriptor.totalAllocatedElements(),
descriptor.elementBytes() * numElementsToCopy,
kind));
return dstOffset;
}
Expand Down Expand Up @@ -601,6 +603,17 @@ namespace TensileLite
return stream;
}

/// Computes how many elements a swizzled tensor occupies after its first two
/// dimensions are rounded up to whole swizzle tiles.
///
/// The descriptor's sizes are read as: index 0 = K extent, index 1 = M extent,
/// index 2 = batch count. K is rounded up to a multiple of `miK * packK`,
/// M is rounded up to a multiple of `miM`, and the result is the product of
/// the two padded extents and the batch count.
///
/// @param desc  Tensor descriptor; its sizes() must expose at least three entries.
/// @param miM   Tile extent used to pad the M dimension (must be non-zero).
/// @param miK   Tile extent along K; multiplied by `packK` to get the K padding unit.
/// @param packK Packing factor along K (the product `miK * packK` must be non-zero).
/// @return paddedM * paddedK * batch, in elements.
size_t getSwizzledTensorNumAllocatedElements(const TensorDescriptor &desc, size_t miM, size_t miK, size_t packK)
{
    // Rounds `value` up to the nearest multiple of `multiple`.
    const auto roundUpToMultiple = [](const auto value, const auto multiple) {
        return (value + multiple - 1) / multiple * multiple;
    };

    const auto& extents = desc.sizes();
    const auto batchCount = extents[2];
    const auto paddedM = roundUpToMultiple(extents[1], miM);
    const auto paddedK = roundUpToMultiple(extents[0], miK * packK);
    return paddedM * paddedK * batchCount;
}

double DataInitialization::GetRepresentativeBetaValue(po::variables_map const& args)
{
auto argValue = args["init-beta"].as<int>();
Expand Down Expand Up @@ -692,23 +705,17 @@ namespace TensileLite
auto& pristine = m_vdata[i].pristine[dataType];
pristine.initDescriptor.resize(1);

//TODO: support more swizzle types
constexpr size_t MiM = 16;
constexpr size_t MiK = 16;
constexpr size_t PackK = 2;
constexpr size_t SwizzleK = MiK * PackK;
auto numAllocatedElements = problem.tensors()[i].totalAllocatedElements();
auto numAllocatedBytes = problem.tensors()[i].totalAllocatedBytes();

if ((problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B))
{
auto& desc = problem.tensors()[i];
auto unrolledSize = desc.sizes()[0];
auto tiledSize = desc.sizes()[1];
unrolledSize = (unrolledSize / SwizzleK + !!(unrolledSize % SwizzleK)) * SwizzleK;
tiledSize = (tiledSize / MiM + !!(tiledSize % MiM)) * MiM;
numAllocatedElements = unrolledSize * tiledSize;
//TODO: support more swizzle types
constexpr size_t MiM = 16;
constexpr size_t MiK = 16;
constexpr size_t PackK = 2;
numAllocatedElements = getSwizzledTensorNumAllocatedElements(problem.tensors()[i], MiM, MiK, PackK);
numAllocatedBytes = numAllocatedElements * GetElementSize(dataType);
}

Expand Down Expand Up @@ -1251,6 +1258,15 @@ namespace TensileLite
else if(m_curBoundsCheck == BoundsCheckMode::GuardPageBack)
{
padding = pUnit.maxElements - problem.tensors()[i].totalAllocatedElements();

if((problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A)
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B))
{
constexpr size_t MiM = 16;
constexpr size_t MiK = 16;
constexpr size_t PackK = 2;
padding = pUnit.maxElements - getSwizzledTensorNumAllocatedElements(problem.tensors()[i], MiM, MiK, PackK);
}
}
padding *= DataTypeInfo::Get(problem.tensors()[i].dataType()).elementSize;
uint8_t* offset = (uint8_t*)pUnit.gpuInput.current.get();
Expand Down Expand Up @@ -1581,24 +1597,35 @@ namespace TensileLite
if(it != m_vdata[i].pristine.end())
{
auto& p = it->second;
ptrdiff_t swizzlePadding{-1};

if(problem.swizzleTensorA() && i == ContractionProblemGemm::TENSOR::A
|| (problem.swizzleTensorB() && i == ContractionProblemGemm::TENSOR::B))
{
swizzlePadding = getSwizzledTensorNumAllocatedElements(desc, 16, 16, 2) - desc.totalAllocatedElements();
}

if(kind == hipMemcpyHostToHost)
ptr = copyNaNInputBuffers(desc,
p.cpuInput.current.get(),
p.cpuInput.valid.get(),
p.maxElements,
kind);
kind,
swizzlePadding);
else if(kind == hipMemcpyHostToDevice)
ptr = copyNaNInputBuffers(desc,
p.gpuInput.current.get(),
p.cpuInput.valid.get(),
p.maxElements,
kind);
kind,
swizzlePadding);
else if(kind == hipMemcpyDeviceToDevice)
ptr = copyNaNInputBuffers(desc,
p.gpuInput.current.get(),
p.gpuInput.valid.get(),
p.maxElements,
kind);
kind,
swizzlePadding);
ptrs.push_back(ptr);
batchPtrs.push_back(p.getInputByKind(kind).batch.get());
maxElements.push_back(p.maxElements);
Expand Down

0 comments on commit 63114c7

Please sign in to comment.