P2p comms through Cuda Ipc #3883

Draft
wants to merge 60 commits into base: main
Commits (60, changes from all commits)
2c04df9
working simple benchmark
samnordmann Oct 23, 2024
af36cf1
minor
samnordmann Oct 25, 2024
68b858a
test script
samnordmann Oct 25, 2024
0c3493b
minor
samnordmann Oct 28, 2024
b30b44b
add nsight profiling
samnordmann Oct 29, 2024
0592a13
nsight and tl/nccl/ sync mode
samnordmann Oct 31, 2024
0037b1e
add cuStreamWriteValue but linkage error
samnordmann Nov 4, 2024
ec71e23
multiple pgs
samnordmann Nov 4, 2024
a15fdfc
reenable cuStreamValue32
samnordmann Nov 4, 2024
6682a33
add tl/cuda and ec/cuda flags in bash test script
samnordmann Nov 4, 2024
b01f1f4
add option to unfuse loops
samnordmann Nov 4, 2024
ea7fd37
add cuda graphs. Only working for NCCL and S1 bc there is a syncStrea…
samnordmann Nov 5, 2024
9dddac2
write matmul to sliced output
samnordmann Nov 26, 2024
faf8bbe
wip cuStreamWriteValue not working
samnordmann Nov 28, 2024
a6b5fd7
dummy benchmark
samnordmann Dec 2, 2024
8d927bf
add pre post comms option
samnordmann Dec 2, 2024
d9c581c
add pre post comms option
samnordmann Dec 2, 2024
bfc7fa6
cleanup test script
samnordmann Dec 6, 2024
1a1138c
update
samnordmann Jan 8, 2025
743185d
Merge branch 'overlap_bench/first_experiments' of github.com:samnordm…
samnordmann Jan 16, 2025
a2b1650
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap_bench/fir…
samnordmann Jan 16, 2025
e037ee5
test with stream parallel type and host IR
samnordmann Jan 16, 2025
8328c28
add support for other dtypes
samnordmann Jan 20, 2025
2fecf02
remove trace print
samnordmann Jan 22, 2025
26f1f7a
add stub files
samnordmann Jan 23, 2025
03b0147
first working example opening cuda ipc handles
samnordmann Jan 24, 2025
7625fab
adding a non-working example with cudaDeviceCanAccessPeer
samnordmann Jan 24, 2025
f703abd
cleanup
samnordmann Jan 28, 2025
abf5c17
AllgatherThroughCudaMemcpyAsync
samnordmann Jan 28, 2025
836d599
refactor to expose choice of backend
samnordmann Jan 29, 2025
e09dd58
add backend type to P2PCommunication
samnordmann Jan 29, 2025
2fbea13
Merge branch 'main' of https://github.com/NVIDIA/Fuser into gpu_comms…
samnordmann Jan 29, 2025
d87c8e6
Merge branch 'overlap_bench/AG_Matmul_with_stream_parallel_type' of g…
samnordmann Jan 29, 2025
6371746
wip
samnordmann Jan 30, 2025
b700a31
working chkpt
samnordmann Jan 30, 2025
1838d1e
remove prints
samnordmann Jan 30, 2025
21eed4a
working chkpt
samnordmann Jan 30, 2025
f455c70
reenable profiling
samnordmann Jan 30, 2025
5a27b7e
fix cache for ipc handles
samnordmann Jan 31, 2025
356feeb
synchronize running stream with original stream at the beginning of p…
samnordmann Feb 3, 2025
4c0736a
lint
samnordmann Feb 3, 2025
7fca035
wip. The send and recv Expr* need to be matched together for associat…
samnordmann Feb 5, 2025
371554e
working chkpt well prepared for two ranks
samnordmann Feb 5, 2025
c7c0404
change signature of P2Pcomms to accept src and dst
samnordmann Feb 7, 2025
6c20a20
working chkpt with get zcopy
samnordmann Feb 11, 2025
f7409b2
working checkpt with many ranks
samnordmann Feb 12, 2025
08f8fe0
chkpt non blocking
samnordmann Feb 12, 2025
de843bb
harden tests by removing hard syncs
samnordmann Feb 12, 2025
4dc9936
use cudaMemcpyAsync
samnordmann Feb 12, 2025
4e05609
clean and lint
samnordmann Feb 12, 2025
326b683
Move distributed tensors to separate file
samnordmann Feb 12, 2025
cf8991c
rename DistributedBuffer to IpcHandle
samnordmann Feb 12, 2025
541fe80
working chkpt. Added in the commit the new files that were forgotten …
samnordmann Feb 12, 2025
a496004
refactor
samnordmann Feb 12, 2025
263d95c
minor cleanup
samnordmann Feb 12, 2025
106d295
lint
samnordmann Feb 12, 2025
ed69f75
minor
samnordmann Feb 12, 2025
5672eca
Merge branch 'main' of github.com:NVIDIA/Fuser into gpu_comms/add_p2p
samnordmann Feb 12, 2025
929ae0d
minor
samnordmann Feb 12, 2025
359779d
move p2p runtime in separate file
samnordmann Feb 12, 2025
16 changes: 15 additions & 1 deletion CMakeLists.txt
@@ -171,6 +171,8 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/mma_type.cpp
${NVFUSER_SRCS_DIR}/multidevice/communication.cpp
${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp
${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp
${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp
${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp
${NVFUSER_SRCS_DIR}/multidevice/executor.cpp
${NVFUSER_SRCS_DIR}/multidevice/utils.cpp
@@ -349,6 +351,7 @@ target_link_libraries(codegen_internal PUBLIC
${LIBCUPTI}
${TORCH_LIBRARIES}
dl
cuda
)

add_library(nvfuser_codegen SHARED $<TARGET_OBJECTS:codegen_internal>)
@@ -615,6 +618,16 @@ if(BUILD_TEST)
target_include_directories(${RNG_TEST_KERNELS} PRIVATE "${NVFUSER_ROOT}")
endif()

if(BUILD_TEST)
set(MULTIDEVICE_TEST_KERNELS "${NVFUSER_TESTS}_multidevice_kernels")
add_library(${MULTIDEVICE_TEST_KERNELS} SHARED ${NVFUSER_ROOT}/tests/cpp/multidevice_kernels.cu)

# CUDA 11 does not support C++20, so hard code C++17 here
set_property(TARGET ${MULTIDEVICE_TEST_KERNELS} PROPERTY CXX_STANDARD 17)
target_link_libraries(${MULTIDEVICE_TEST_KERNELS} PRIVATE torch ${TORCH_LIBRARIES} codegen_internal)
target_include_directories(${MULTIDEVICE_TEST_KERNELS} PRIVATE "${NVFUSER_ROOT}")
endif()

function(add_test_without_main TEST_NAME TEST_SRC ADDITIONAL_LINK)
list(APPEND TEST_SRC
${NVFUSER_ROOT}/tests/cpp/utils.cpp
@@ -673,8 +686,9 @@ if(BUILD_TEST)
${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_transformer.cpp
${NVFUSER_ROOT}/tests/cpp/test_multidevice_gpu_comms.cpp
)
add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" "")
add_test_without_main(test_multidevice "${MULTIDEVICE_TEST_SRCS}" ${MULTIDEVICE_TEST_KERNELS})
list(APPEND TEST_BINARIES test_multidevice)

set(MULTIDEVICE_TUTORIAL_SRCS)
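Note on the new `cuda` entry in target_link_libraries: the CUDA P2P path calls driver-API functions such as cuStreamWriteValue32 and cuStreamWaitValue32, whose symbols live in libcuda.so rather than in the runtime library, so codegen_internal now links `cuda` (this also matches the earlier "add cuStreamWriteValue but linkage error" commit). The snippet below is only an illustrative sketch of such a call, not code from this PR:

// Illustrative only: a driver-API stream write, which needs libcuda ("cuda") at link time.
#include <cuda.h>          // driver API (cu*): cuStreamWriteValue32, CUdeviceptr, ...
#include <cuda_runtime.h>  // runtime API (cuda*): cudaMalloc, cudaStreamCreate, ...

int main() {
  cudaSetDevice(0);  // also initializes the primary context used by the driver call
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  void* flag = nullptr;
  cudaMalloc(&flag, sizeof(unsigned int));
  // Undefined reference at link time unless the target links against `cuda`.
  cuStreamWriteValue32(
      stream, reinterpret_cast<CUdeviceptr>(flag), 1u, CU_STREAM_WRITE_VALUE_DEFAULT);
  cudaStreamSynchronize(stream);
  cudaFree(flag);
  return 0;
}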
7 changes: 7 additions & 0 deletions bench/process_outputs
@@ -0,0 +1,7 @@
#!/bin/bash

FILE="/opt/pytorch/Fuser/bench/logs/${1}/info"

cat $FILE | grep "rank 0: " #| awk '{print $4}'

# | grep -E 'Streams32\b'
90 changes: 90 additions & 0 deletions bench/test
@@ -0,0 +1,90 @@
#!/bin/bash
EXPERIMENT=CUDA_tests
DATE=$(date +%Y%m%d-%H%M)
LOG_BASE="/opt/pytorch/Fuser/bench/logs"

NP=8
BACKEND=CUDA
M=32768
K=32768
N=1024

DTYPE="float" #"__half" # float, __bfloat

S=8
Streams=3
Pgs=1

# M=131072 #32768
# K=131072
# N=32768 #1024
# L=1048576 #268435456 #67108864 #131072
# PRE_COMM="_pre_comm"
# POST_COMM="_post_comm"
# UNFUSE="_unfused"
# GRAPH="_WithCudaGraph"
# cuStreamWrite=WithcuStreamWriteValue32_
# GTEST_PREFIX="OverlapBenchmark.PipelinedAGMatmulBenchmark/"
# GTEST_PREFIX="DummyOverlapBenchmark.PipelinedAGMatmulBenchmark/"
GTEST_PREFIX="OverlapBenchmark.PipelinedAGMatmulBenchmarkStreamParallelType/"
GTEST_POSTFIX="${BACKEND}_S${S}_M${M}_K${K}_N${N}_Streams${Streams}_${DTYPE}_${cuStreamWrite}Pgs${Pgs}${UNFUSE}${GRAPH}"
# GTEST_POSTFIX="${BACKEND}_M${M}_K${K}_N${N}_L${L}${PRE_COMM}${POST_COMM}"
export GTEST_FILTER="${GTEST_PREFIX}${GTEST_POSTFIX}"
echo "gtest filter: $GTEST_FILTER" | tee -a $LOG_FILE_INFO

MPIFLAGS=" -np $NP"

# MPIFLAGS+=" -x NCCL_P2P_NET_CHUNKSIZE=2MB"
# MPIFLAGS+=" -x NCCL_DEBUG=TRACE" #INFO
# MPIFLAGS+=" -x NCCL_MAX_NCHANNELS=1"

MPIFLAGS+=" -x UCC_CL_BASIC_TLS=nccl"
# MPIFLAGS+=" -x UCC_TL_NCCL_SYNC=event"

# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=cuda"
# MPIFLAGS+=" -x UCC_TL_CUDA_SCRATCH_SIZE=32mb"
# MPIFLAGS+=" -x UCC_TL_CUDA_ALLGATHER_RING_MAX_RINGS=32"
# MPIFLAGS+=" -x UCC_TL_CUDA_ALLGATHER_RING_NUM_CHUNKS=32"

# MPIFLAGS+=" -x UCC_EC_CUDA_EXEC_NUM_WORKERS=8"
# MPIFLAGS+=" -x UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH=0"
# MPIFLAGS+=" -x UCC_EC_CUDA_STREAM_TASK_MODE=driver"
# MPIFLAGS+=" -x UCC_EC_CUDA_STREAM_TASK_MODE=kernel"
# MPIFLAGS+=" -x UCC_EC_CUDA_EXEC_COPY_LARGE_THRESH=1M"
# MPIFLAGS+=" -x UCC_EC_CUDA_EXEC_NUM_THREADS=512"

# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=ucp"
# MPIFLAGS+=" -x UCX_RNDV_THRESH=0 -x UCX_TLS=ib,cuda_copy"
# MPIFLAGS+=" -x UCX_RNDV_SCHEME=put_zcopy"
# MPIFLAGS+=" -x UCX_RNDV_SCHEME=get_zcopy"


MPIFLAGS+=" -x UCX_NET_DEVICES=mlx5_0:1"
# MPIFLAGS+=" -x UCC_CL_BASIC_TLS=^sharp,mlx5"
# MPIFLAGS+=" -x UCC_COLL_TRACE=info"
# MPIFLAGS+=" -x UCC_LOG_LEVEL=debug"
# MPIFLAGS+=" -x TORCH_NCCL_AVOID_RECORD_STREAMS=1"
# MPIFLAGS+=" -x CUDA_DEVICE_MAX_CONNECTIONS=2"


export LOGS="${LOG_BASE}/${EXPERIMENT}_${BACKEND}_${DATE}"
mkdir -p $LOGS
export LOG_FILE_INFO="${LOGS}/info.txt"
echo "Writing to $LOG_FILE_INFO" | tee -a $LOG_FILE_INFO

echo "mpi flags: $MPIFLAGS" | tee -a $LOG_FILE_INFO

TEST_CMD="$BUILD_DIRECTORY/test_multidevice --gtest_filter=${GTEST_FILTER}"
echo "test cmd: $TEST_CMD" | tee -a $LOG_FILE_INFO

MPICMD="mpirun $MPIFLAGS $TEST_CMD"
echo $MPICMD | tee -a $LOG_FILE_INFO

# opt/pytorch/scripts/nsight/install-nsight.sh
NSYS=$(sudo which nsys)
NSYSCMD="${NSYS} profile --stats=false -w true -t cublas,cuda,nvtx,osrt,mpi,ucx -o ${LOGS}/${GTEST_POSTFIX} --capture-range-end stop --capture-range=cudaProfilerApi --cudabacktrace=memory,sync,kernel,other"

CMD="${NSYSCMD} ${MPICMD}"
sudo /bin/sh -c "echo '1' > /proc/sys/kernel/perf_event_paranoid"
echo $CMD | tee -a ${LOG_FILE_INFO}
$CMD | tee -a ${LOG_FILE_INFO}
3 changes: 2 additions & 1 deletion csrc/dispatch.h
@@ -157,7 +157,8 @@ class Val;
f(Wait); \
f(Synchronize); \
f(StartCoalescing); \
f(EndCoalescing);
f(EndCoalescing); \
f(ShareMemHandles);

// Forward declarations for all Val and Expr types

102 changes: 79 additions & 23 deletions csrc/host_ir/executor.cpp
@@ -15,6 +15,7 @@
#include <instrumentation.h>
#include <ir/utils.h>
#include <multidevice/communication.h>
#include <multidevice/cuda_p2p.h>
#include <multidevice/utils.h>
#include <options.h>
#include <runtime/allocations.h>
@@ -69,7 +70,8 @@ void HostIrExecutor::compile(Fusion* fusion) {
} else {
std::vector<Expr*> exprs = fusion->exprs();
for (Expr* e : exprs) {
std::vector<Expr*> communications = HostIrLower::lower(cloner.clone(e));
HostIrLower lower;
std::vector<Expr*> communications = lower.lower(cloner.clone(e));
for (auto* communication : communications) {
host_ir_container_->pushBackTopLevelExprs(communication);
}
@@ -189,7 +191,7 @@ HostIrEvaluator::HostIrEvaluator(
: container_(std::move(container)),
communicator_(communicator),
params_(params),
my_device_index_(communicator_ ? communicator_->deviceId() : 0) {
my_local_device_index_(communicator_ ? communicator_->local_rank() : 0) {
const DeviceIdxType device_index =
(communicator_ != nullptr && communicator_->is_available())
? communicator_->deviceId()
@@ -279,13 +281,13 @@ void HostIrEvaluator::handle(GetCurrentStream* get_current_stream) {
streams_.insert(
{get_current_stream->stream(),
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))});
static_cast<c10::DeviceIndex>(my_local_device_index_))});
}

void HostIrEvaluator::handle(Synchronize* synchronize) {
cudaStream_t current_stream =
c10::cuda::getCurrentCUDAStream(
static_cast<c10::DeviceIndex>(my_device_index_))
static_cast<c10::DeviceIndex>(my_local_device_index_))
.stream();
cudaStream_t stream_to_sync = getCUDAStream(synchronize->stream()).stream();

@@ -408,6 +410,11 @@ void HostIrEvaluator::handle(PostOnStream* post_ir) {
}
}

void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) {
ipc_handle_cache_.exchangeHandles(
share_mem_handles->communications(), expr_evaluator_);
}

void HostIrEvaluator::handle(Communication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
@@ -418,40 +425,89 @@ void HostIrEvaluator::handle(Communication* communication) {
at::Tensor output_tensor =
getKnownTensorOrUndefined(communication->output(0), expr_evaluator_);

c10d::Backend* backend =
communicator_->getBackendForTeam(communication->team(), std::nullopt);
works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
backend,
input_tensor,
output_tensor);
CommunicatorBackend backend_type = communication->backend();

if (backend_type != CommunicatorBackend::kCuda) {
c10d::Backend* backend =
communicator_->getBackendForTeam(communication->team(), backend_type);
works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
backend,
input_tensor,
output_tensor);
return;
}

NVF_ERROR(communication->type() == CommunicationType::Allgather);
}

void HostIrEvaluator::handle(P2PCommunication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
"A valid communicator must be provided");

const int64_t my_rank = communicator_->deviceId();
const auto dst = expr_evaluator_.evaluate(communication->dst()).as<int64_t>();
const auto src = expr_evaluator_.evaluate(communication->src()).as<int64_t>();
const bool is_sender = my_rank == src;
const bool is_receiver = my_rank == dst;
if (!(is_sender ^ is_receiver)) {
return;
}

CommunicatorBackend backend_type = communication->backend();
at::Tensor buffer =
getKnownTensorOrUndefined(communication->buffer(), expr_evaluator_);

works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
expr_evaluator_.evaluate(communication->peer()).as<int64_t>(),
communicator_->getWorld(),
buffer);
if (backend_type != CommunicatorBackend::kCuda) {
works_[communication] = postSingleCommunication(
communication,
communicator_->deviceId(),
expr_evaluator_.evaluate(communication->dst()).as<int64_t>(),
expr_evaluator_.evaluate(communication->src()).as<int64_t>(),
communicator_->getWorld(),
buffer);
return;
}

const P2pIpcHandle& ipc_handles =
ipc_handle_cache_.get(communication, expr_evaluator_);
const auto current_stream = static_cast<CUstream>(
c10::cuda::getCurrentCUDAStream(my_local_device_index_).stream());
if (is_receiver) {
getZcopy::RecvPost(ipc_handles, buffer.numel() * buffer.element_size(), current_stream);
} else /*sender*/ {
getZcopy::SendPost(ipc_handles, current_stream);
}
}

void HostIrEvaluator::handle(Wait* wait) {
Expr* communication = wait->communication();
NVF_ERROR(works_.find(communication) != works_.end(), "no wait req");
auto& work = works_.at(communication);
if (work != nullptr) {
work->wait();
auto* p2p_comm = dynamic_cast<P2PCommunication*>(communication);
if (p2p_comm && p2p_comm->backend() != CommunicatorBackend::kCuda) {
auto it = works_.find(communication);
if (it == works_.end()) {
return;
}
auto& work = it->second;
if (work != nullptr) {
work->wait();
}
works_.erase(communication);
return;
}

const auto src = expr_evaluator_.evaluate(p2p_comm->src()).as<int64_t>();
const auto dst = expr_evaluator_.evaluate(p2p_comm->dst()).as<int64_t>();
const int64_t my_rank = communicator_->deviceId();
if (my_rank == src && src != dst) {
const auto current_stream = static_cast<CUstream>(
c10::cuda::getCurrentCUDAStream(my_local_device_index_).stream());
const P2pIpcHandle& ipc_handles =
ipc_handle_cache_.get(p2p_comm, expr_evaluator_);
getZcopy::SendWait(ipc_handles, current_stream);
}
works_.erase(communication);
}

namespace {
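For context on the kCuda branch above: in the get-zcopy pattern the receiver pulls the payload straight out of the sender's IPC-mapped buffer and then signals completion, while the sender only announces readiness and later waits for that completion flag before reusing its buffer. The sketch below is a hypothetical illustration of what getZcopy::SendPost / RecvPost / SendWait could look like using CUDA stream memory operations; the actual implementation lives in csrc/multidevice/cuda_p2p.cpp (not shown in this diff), the handle field names are invented for the example, and semaphore reset between reuses is omitted.

// Hypothetical get-zcopy sketch (not the PR's cuda_p2p.cpp).
// Assumed handle layout: local/peer buffer pointers plus two device-visible
// 32-bit semaphores, all exchanged ahead of time through CUDA IPC.
#include <cuda.h>
#include <cuda_runtime.h>

struct P2pIpcHandleSketch {    // assumption: the PR's P2pIpcHandle differs
  void* local_buffer;          // this rank's buffer
  void* peer_buffer;           // peer's buffer, imported via CUDA IPC
  CUdeviceptr ready_semaphore; // sender sets to 1 when its data is valid
  CUdeviceptr done_semaphore;  // receiver sets to 1 when the copy finished
};

// Sender: publish "my buffer is ready to be read".
void SendPost(const P2pIpcHandleSketch& h, CUstream stream) {
  cuStreamWriteValue32(stream, h.ready_semaphore, 1u, CU_STREAM_WRITE_VALUE_DEFAULT);
}

// Receiver: wait for readiness, pull the bytes, then publish "done".
void RecvPost(const P2pIpcHandleSketch& h, size_t count, CUstream stream) {
  cuStreamWaitValue32(stream, h.ready_semaphore, 1u, CU_STREAM_WAIT_VALUE_EQ);
  cudaMemcpyAsync(h.local_buffer, h.peer_buffer, count,
                  cudaMemcpyDeviceToDevice, stream);
  cuStreamWriteValue32(stream, h.done_semaphore, 1u, CU_STREAM_WRITE_VALUE_DEFAULT);
}

// Sender: block its stream until the receiver has finished copying.
void SendWait(const P2pIpcHandleSketch& h, CUstream stream) {
  cuStreamWaitValue32(stream, h.done_semaphore, 1u, CU_STREAM_WAIT_VALUE_EQ);
}

This matches the call sites in the evaluator: the receiver posts RecvPost with the byte count, the sender posts SendPost, and the sender's Wait handler issues SendWait on its current stream.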
5 changes: 4 additions & 1 deletion csrc/host_ir/executor.h
@@ -12,6 +12,7 @@
#include <host_ir/container.h>
#include <host_ir/host_ir.h>
#include <multidevice/communicator.h>
#include <multidevice/ipc_handle.h>
#include <runtime/executor.h>
#include <runtime/executor_abstract.h>
#include <runtime/executor_params.h>
@@ -129,6 +130,7 @@ class HostIrEvaluator final : public OptOutDispatch {
void handle(MatmulOp* matmul) override;
void handle(LinearOp* linear) override;
void handle(kir::Allocate* allocate) override;
void handle(ShareMemHandles* share_mem_handles) override;
void unhandled(Statement* stmt) override;

c10::cuda::CUDAStream getCUDAStream(Stream* stream);
@@ -144,7 +146,8 @@
using StreamKey = std::variant<int64_t, Stream*>;
std::unordered_map<StreamKey, c10::cuda::CUDAStream> streams_;
std::unordered_map<Expr*, c10::intrusive_ptr<c10d::Work>> works_;
const int64_t my_device_index_;
const int64_t my_local_device_index_;
IpcHandleCache ipc_handle_cache_;
};

} // namespace hir
28 changes: 28 additions & 0 deletions csrc/host_ir/host_ir.cpp
@@ -323,6 +323,34 @@ std::string EndCoalescing::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

ShareMemHandles::ShareMemHandles(
IrBuilderPasskey passkey,
std::vector<P2PCommunication*> communications)
: Expr(passkey) {
NVF_ERROR(passkey.ir_container_ != nullptr);
NVF_ERROR(
passkey.ir_container_->isA<HostIrContainer>(),
this,
"must be registered in a HostIrContainer");
addDataAttribute(std::move(communications));
}

NVFUSER_DEFINE_CLONE_AND_CREATE(ShareMemHandles)

std::string ShareMemHandles::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "ShareMemHandles(";
for (auto* communication : communications()) {
ss << communication->toInlineString() << ", ";
}
ss << ")" << std::endl;
return ss.str();
}

std::string ShareMemHandles::toInlineString(int indent_size) const {
NVF_CHECK(false, "Cannot be printed inline");
}

} // namespace hir

} // namespace nvfuser
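ShareMemHandles marks the point at which ranks exchange CUDA IPC memory handles before any kCuda P2P communication is posted; the evaluator forwards it to IpcHandleCache::exchangeHandles. For background, the export/import itself follows the standard CUDA IPC pattern sketched below. This is a generic illustration that assumes the raw handle bytes travel between ranks over some CPU-side channel (e.g. the communicator's store); it is not the PR's csrc/multidevice/ipc_handle.cpp implementation.

// Generic CUDA IPC export/import pattern (illustrative sketch only).
#include <cuda_runtime.h>
#include <cstring>
#include <vector>

// Exporter (the rank that owns the buffer): serialize a handle for its allocation.
std::vector<char> exportIpcHandle(void* device_ptr) {
  cudaIpcMemHandle_t handle;
  cudaIpcGetMemHandle(&handle, device_ptr);
  std::vector<char> bytes(sizeof(handle));
  std::memcpy(bytes.data(), &handle, sizeof(handle));
  return bytes;  // sent to the peer via any CPU-side channel (store, MPI, TCP, ...)
}

// Importer (the peer): map the exporter's allocation into its own address space.
void* importIpcHandle(const std::vector<char>& bytes) {
  cudaIpcMemHandle_t handle;
  std::memcpy(&handle, bytes.data(), sizeof(handle));
  void* peer_ptr = nullptr;
  cudaIpcOpenMemHandle(&peer_ptr, handle, cudaIpcMemLazyEnablePeerAccess);
  return peer_ptr;  // must later be released with cudaIpcCloseMemHandle(peer_ptr)
}

A cache keyed on the communication and the peer, like the IpcHandleCache member added to HostIrEvaluator, presumably lets each handle be opened once and reused across iterations instead of re-exchanging it on every post.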