Skip to content

Commit

Permalink
hip pinned allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
apwojcik committed Feb 3, 2025
1 parent b6ab00f commit 1f3a3fb
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/migraphx/migraphx_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ void* MIGraphXExternalAllocator::Reserve(size_t size) {
return p;
}

void* HIPPinnedAllocator::Alloc(size_t size) {
void* MIGraphXPinnedAllocator::Alloc(size_t size) {
void* p = nullptr;
if (size > 0) {
HIP_CALL_THROW(hipHostMalloc((void**)&p, size));
}
return p;
}

void HIPPinnedAllocator::Free(void* p) {
void MIGraphXPinnedAllocator::Free(void* p) {
HIP_CALL_THROW(hipHostFree(p));
}

Expand Down
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/migraphx/migraphx_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,16 @@ class MIGraphXExternalAllocator : public MIGraphXAllocator {
std::unordered_set<void*> reserved_;
};

// TODO: add a default constructor
class HIPPinnedAllocator : public IAllocator {
class MIGraphXPinnedAllocator final : public IAllocator {
public:
HIPPinnedAllocator(int device_id, const char* name)
MIGraphXPinnedAllocator(const int device_id, const char* name)
: IAllocator(
OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator,
OrtMemoryInfo(name, OrtDeviceAllocator,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(device_id)),
device_id, OrtMemTypeCPUOutput)) {}

virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
void* Alloc(size_t size) override;
void Free(void* p) override;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ std::vector<AllocatorPtr> MIGraphXExecutionProvider::CreatePreferredAllocators()
[](OrtDevice::DeviceId device_id) { return std::make_unique<MIGraphXAllocator>(device_id, onnxruntime::CUDA); }, info_.device_id);
AllocatorCreationInfo pinned_allocator_info(
[](OrtDevice::DeviceId device_id) {
return std::make_unique<HIPPinnedAllocator>(device_id, onnxruntime::CUDA_PINNED);
return std::make_unique<MIGraphXPinnedAllocator>(device_id, onnxruntime::CUDA_PINNED);
},
0);
return std::vector<AllocatorPtr>{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX {
}

std::unique_ptr<IAllocator> CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override {
return std::make_unique<HIPPinnedAllocator>(device_id, name);
return std::make_unique<MIGraphXPinnedAllocator>(device_id, name);
}

void MIGraphXMemcpy_HostToDevice(void* dst, const void* src, size_t count) override {
Expand Down

0 comments on commit 1f3a3fb

Please sign in to comment.