Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CudaIpc 3/3]: p2p get-Zcopy #3894

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/mma_type.cpp
${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
${NVFUSER_SRCS_DIR}/multidevice/executor.cpp
${NVFUSER_SRCS_DIR}/multidevice/utils.cpp
Expand Down Expand Up @@ -349,6 +351,7 @@ target_link_libraries(codegen_internal PUBLIC
${LIBCUPTI}
${TORCH_LIBRARIES}
dl
cuda
)

add_library(nvfuser_codegen SHARED $<TARGET_OBJECTS:codegen_internal>)
Expand Down
3 changes: 2 additions & 1 deletion csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ class Val;
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
f(EndCoalescing);
f(EndCoalescing); \
f(ShareMemHandles);

// Forward declarations for all Val and Expr types

Expand Down
69 changes: 58 additions & 11 deletions csrc/host_ir/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <instrumentation.h>
#include <ir/utils.h>
#include <multidevice/communication.h>
#include <multidevice/cuda_p2p.h>
#include <multidevice/utils.h>
#include <options.h>
#include <runtime/allocations.h>
Expand Down Expand Up @@ -405,6 +406,11 @@ void HostIrEvaluator::handle(PostOnStream* post_ir) {
}
}

void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) {
ipc_handle_cache_.exchangeHandles(
share_mem_handles->communications(), expr_evaluator_);
}

void HostIrEvaluator::handle(Communication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
Expand All @@ -430,25 +436,66 @@ void HostIrEvaluator::handle(P2PCommunication* communication) {
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

const int64_t my_rank = communicator_->deviceId();
const auto dst = expr_evaluator_.evaluate(communication->dst()).as<int64_t>();
const auto src = expr_evaluator_.evaluate(communication->src()).as<int64_t>();
const bool is_sender = my_rank == src;
const bool is_receiver = my_rank == dst;
if (!(is_sender ^ is_receiver)) {
return;
}

CommunicatorBackend backend_type = communication->backend();
at::Tensor buffer =
getKnownTensorOrUndefined(communication->buffer(), expr_evaluator_);

works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
expr_evaluator_.evaluate(communication->peer()).as<int64_t>(),
communicator_->getWorld(),
buffer);
if (backend_type == CommunicatorBackend::kCuda) {
const P2pIpcHandle& ipc_handles =
ipc_handle_cache_.get(communication, expr_evaluator_);
const auto current_stream = static_cast<CUstream>(
c10::cuda::getCurrentCUDAStream(communicator_->local_rank()).stream());
if (is_receiver) {
getZcopy::RecvPost(
ipc_handles, buffer.numel() * buffer.element_size(), current_stream);
} else /*sender*/ {
getZcopy::SendPost(ipc_handles, current_stream);
}
} else {
works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
dst,
src,
communicator_->getWorld(),
buffer);
}
}

void HostIrEvaluator::handle(Wait* wait) {
Expr* communication = wait->communication();
NVF_ERROR(works_.find(communication) != works_.end(), "no wait req");
auto& work = works_.at(communication);
if (work != nullptr) {
work->wait();
auto* p2p_comm = dynamic_cast<P2PCommunication*>(communication);
if (p2p_comm && p2p_comm->backend() == CommunicatorBackend::kCuda) {
const auto src = expr_evaluator_.evaluate(p2p_comm->src()).as<int64_t>();
const auto dst = expr_evaluator_.evaluate(p2p_comm->dst()).as<int64_t>();
const int64_t my_rank = communicator_->deviceId();
if (my_rank == src && src != dst) {
const auto current_stream = static_cast<CUstream>(
c10::cuda::getCurrentCUDAStream(communicator_->local_rank()).stream());
const P2pIpcHandle& ipc_handles =
ipc_handle_cache_.get(p2p_comm, expr_evaluator_);
getZcopy::SendWait(ipc_handles, current_stream);
}
} else {
auto it = works_.find(communication);
if (it == works_.end()) {
return;
}
auto& work = it->second;
if (work != nullptr) {
work->wait();
}
works_.erase(communication);
}
works_.erase(communication);
}

namespace {
Expand Down
3 changes: 3 additions & 0 deletions csrc/host_ir/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <host_ir/container.h>
#include <host_ir/host_ir.h>
#include <multidevice/communicator.h>
#include <multidevice/ipc_handle.h>
#include <runtime/executor.h>
#include <runtime/executor_abstract.h>
#include <runtime/executor_params.h>
Expand Down Expand Up @@ -129,6 +130,7 @@ class HostIrEvaluator final : public OptOutDispatch {
void handle(MatmulOp* matmul) override;
void handle(LinearOp* linear) override;
void handle(kir::Allocate* allocate) override;
void handle(ShareMemHandles* share_mem_handles) override;
void unhandled(Statement* stmt) override;

c10::cuda::CUDAStream getCUDAStream(Stream* stream);
Expand All @@ -145,6 +147,7 @@ class HostIrEvaluator final : public OptOutDispatch {
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_device_index_;
IpcHandleCache ipc_handle_cache_;
};

} // namespace hir
Expand Down
28 changes: 28 additions & 0 deletions csrc/host_ir/host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,34 @@ std::string EndCoalescing::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

ShareMemHandles::ShareMemHandles(
IrBuilderPasskey passkey,
std::vector<P2PCommunication*> communications)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<HostIrContainer>(),
this,
"must be registered in a HostIrContainer");
addDataAttribute(std::move(communications));
}

NVFUSER_DEFINE_CLONE_AND_CREATE(ShareMemHandles)

std::string ShareMemHandles::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "ShareMemHandles(";
for (auto communication : communications()) {
ss << communication->toInlineString() << ", ";
}
ss << std::endl;
return ss.str();
}

std::string ShareMemHandles::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

} // namespace hir

} // namespace nvfuser
25 changes: 25 additions & 0 deletions csrc/host_ir/host_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,31 @@ class EndCoalescing : public Expr {
}
};

class ShareMemHandles : public Expr {
public:
using Expr::Expr;
ShareMemHandles(
IrBuilderPasskey passkey,
std::vector<P2PCommunication*> communications);

ShareMemHandles(const ShareMemHandles& other) = delete;
ShareMemHandles& operator=(const ShareMemHandles& other) = delete;
ShareMemHandles(ShareMemHandles&& other) = delete;
ShareMemHandles& operator=(ShareMemHandles&& other) = delete;

NVFUSER_DECLARE_CLONE_AND_CREATE

std::string toString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;
const char* getOpString() const override {
return "hir::ShareMemHandles";
}

const std::vector<P2PCommunication*>& communications() const {
return attribute<std::vector<P2PCommunication*>>(0);
}
};

} // namespace hir

} // namespace nvfuser
52 changes: 20 additions & 32 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,44 +213,33 @@ std::string Communication::toInlineString(int indent_size) const {
return toString(indent_size);
}

std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type) {
switch (type) {
case P2PCommunicationType::SEND:
os << "send";
break;
case P2PCommunicationType::RECV:
os << "recv";
break;
default:
NVF_THROW("unrecognized P2PCommunicationType: ", type);
}
return os;
}

P2PCommunication::P2PCommunication(
IrBuilderPasskey passkey,
P2PCommunicationType type,
TensorView* buffer,
Val* peer)
Val* dst,
Val* src,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(buffer);
addDataAttribute(type);
addAttribute(peer);
addAttribute(dst);
addAttribute(src);
addDataAttribute(backend);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(P2PCommunication)

std::string P2PCommunication::toString(const int indent_size) const {
std::string P2PCommunication::toInlineString(const int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "P2PCommunication " << name() << " ("
<< "type=" << type() << ", "
<< "buffer=" << buffer() << ", "
<< "peer=" << peer() << ")\n";
<< "dst=" << dst() << ", "
<< "src=" << src() << ", "
<< "backend=" << backend() << ")";
return ss.str();
}

std::string P2PCommunication::toInlineString(int indent_size) const {
return toString(indent_size);
std::string P2PCommunication::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

namespace {
Expand Down Expand Up @@ -584,19 +573,18 @@ c10::intrusive_ptr<c10d::Work> postRecv(
c10::intrusive_ptr<c10d::Work> postSingleCommunication(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
DeviceIdxType dst,
DeviceIdxType src,
c10d::Backend* backend,
at::Tensor buffer) {
NVF_ERROR(backend != nullptr);

switch (communication->type()) {
case P2PCommunicationType::SEND:
return postSend(communication, my_device_index, peer, backend, buffer);
case P2PCommunicationType::RECV:
return postRecv(communication, my_device_index, peer, backend, buffer);
default:
NVF_THROW("Wrong communication type: ", communication->type());
return nullptr;
if (my_device_index == src) {
return postSend(communication, my_device_index, dst, backend, buffer);
} else if (my_device_index == dst) {
return postRecv(communication, my_device_index, src, backend, buffer);
} else {
return nullptr;
}
}

Expand Down
26 changes: 14 additions & 12 deletions csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,16 @@ class Communication : public Expr {
void validate();
};

enum class P2PCommunicationType { SEND, RECV };

std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type);

class P2PCommunication : public Expr {
public:
using Expr::Expr;

P2PCommunication(
IrBuilderPasskey passkey,
P2PCommunicationType type,
TensorView* buffer,
Val* peer);
Val* dst,
Val* src,
CommunicatorBackend backend = CommunicatorBackend::kNccl);

P2PCommunication(const P2PCommunication& other) = delete;
P2PCommunication& operator=(const P2PCommunication& other) = delete;
Expand All @@ -143,17 +140,21 @@ class P2PCommunication : public Expr {
return "P2PCommunication";
}

P2PCommunicationType type() const {
return attribute<P2PCommunicationType>(0);
}

TensorView* buffer() const {
return input(0)->as<TensorView>();
}

Val* peer() const {
Val* dst() const {
return attributeVal(0);
}

Val* src() const {
return attributeVal(1);
}

auto backend() const {
return attribute<CommunicatorBackend>(2);
}
};

// The method "post" triggers the execution of the communication. This call is
Expand Down Expand Up @@ -225,7 +226,8 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
c10::intrusive_ptr<c10d::Work> postSingleCommunication(
P2PCommunication* communication,
DeviceIdxType my_device_index,
DeviceIdxType peer,
DeviceIdxType dst,
DeviceIdxType src,
c10d::Backend* backend,
at::Tensor buffer);

Expand Down
3 changes: 3 additions & 0 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb) {
case CommunicatorBackend::kGloo:
out << "GLOO";
break;
case CommunicatorBackend::kCuda:
out << "CUDA";
break;
}
return out;
}
Expand Down
7 changes: 4 additions & 3 deletions csrc/multidevice/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ namespace nvfuser {

using RankType = DeviceIdxType;

// Supported backends. TODO: gloo untested
enum class CommunicatorBackend { kNccl, kUcc, kGloo };

std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb);

#ifdef USE_C10D_NCCL
Expand Down Expand Up @@ -123,6 +120,10 @@ class Communicator {
return false;
}

auto getTcpStore() {
return store_;
}

private:
Communicator(
CommunicatorBackend backend = comm_backend_default,
Expand Down
Loading
Loading