From 96f94eaf8411b2ac0f28c041b6891b9faff7db51 Mon Sep 17 00:00:00 2001 From: zhushuang <974198603@qq.com> Date: Thu, 12 Feb 2026 16:20:36 +0800 Subject: [PATCH 1/3] issue/1021 - feat: support bf16 for mccl (build failed due to missing mcclBfloat16) --- src/infiniccl/moore/infiniccl_moore.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniccl/moore/infiniccl_moore.cc b/src/infiniccl/moore/infiniccl_moore.cc index b58b2d63a..4dee4aebd 100644 --- a/src/infiniccl/moore/infiniccl_moore.cc +++ b/src/infiniccl/moore/infiniccl_moore.cc @@ -23,6 +23,8 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) { return mcclFloat; case INFINI_DTYPE_F16: return mcclHalf; + case INFINI_DTYPE_BF16: + return mcclBfloat16; default: std::abort(); return mcclHalf; @@ -83,9 +85,7 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) { - return INFINI_STATUS_BAD_PARAM; - } + CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype), getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream))); From a8d2c7328a5e8d4b2a5556191a65e49322fc12a1 Mon Sep 17 00:00:00 2001 From: "zichen.wang" Date: Thu, 26 Feb 2026 14:54:06 +0800 Subject: [PATCH 2/3] fix moore can not compile mccl bf16 --- xmake/moore.lua | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xmake/moore.lua b/xmake/moore.lua index 2cd55e8b2..908a332cc 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -16,7 +16,7 @@ rule("mu") local mcc = MUSA_ROOT .. "/bin/mcc" local includedirs = table.concat(target:get("includedirs"), " ") - local args = {"-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} + local args = {"--cuda-gpu-arch=mp_31", "-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. 
"/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} for _, includedir in ipairs(target:get("includedirs")) do table.insert(args, "-I" .. includedir) end @@ -76,6 +76,8 @@ target("infiniccl-moore") if has_config("ccl") then add_links("libmccl.so") add_files("../src/infiniccl/moore/*.cc") + add_defines("MARCH_TYPE=310") + add_cxxflags("-Wno-unused-function") end set_languages("cxx17") From 17996414c708788bda32410889a6b952e99e28d8 Mon Sep 17 00:00:00 2001 From: "zichen.wang" Date: Mon, 9 Mar 2026 11:32:26 +0800 Subject: [PATCH 3/3] cr --- src/infiniccl/moore/custom_all_reduce.h | 41 ++ src/infiniccl/moore/custom_all_reduce.mu | 152 +++++ src/infiniccl/moore/custom_all_reduce.muh | 711 ++++++++++++++++++++++ src/infiniccl/moore/infiniccl_moore.cc | 177 +++++- src/infiniccl/moore/infiniccl_moore.h | 20 + src/infiniccl/moore/utils.h | 481 +++++++++++++++ xmake/moore.lua | 14 +- 7 files changed, 1588 insertions(+), 8 deletions(-) create mode 100644 src/infiniccl/moore/custom_all_reduce.h create mode 100644 src/infiniccl/moore/custom_all_reduce.mu create mode 100644 src/infiniccl/moore/custom_all_reduce.muh create mode 100644 src/infiniccl/moore/utils.h diff --git a/src/infiniccl/moore/custom_all_reduce.h b/src/infiniccl/moore/custom_all_reduce.h new file mode 100644 index 000000000..c69593640 --- /dev/null +++ b/src/infiniccl/moore/custom_all_reduce.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// namespace custom_allreduce { + +using fptr_t = int64_t; + +fptr_t init_custom_ar( + const std::vector& fake_ipc_ptrs, + void* rank_data, + size_t rank_data_sz, + int64_t rank, + bool full_nvlink); + +void all_reduce( + fptr_t _fa, + void* inp, + void* out, + size_t rank_data_sz, + mcclDataType_t datatype, + fptr_t _reg_buffer, + int64_t reg_buffer_sz_bytes, + musaStream_t stream); + +int64_t meta_size(); + +void register_buffer( + fptr_t _fa, + const std::vector& fake_ipc_ptrs); + +// } // namespace 
custom_allreduce + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/src/infiniccl/moore/custom_all_reduce.mu b/src/infiniccl/moore/custom_all_reduce.mu new file mode 100644 index 000000000..b7a7a6444 --- /dev/null +++ b/src/infiniccl/moore/custom_all_reduce.mu @@ -0,0 +1,152 @@ +// Adapted from: https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu +// #include "torch_musa/csrc/aten/musa/Exceptions.h" +// #include "torch_musa/csrc/core/MUSAGuard.h" +// #include "torch_musa/csrc/core/MUSAStream.h" + +#include +#include +#include "custom_all_reduce.muh" + +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +extern "C" fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, void* rank_data, size_t rank_data_sz, int64_t rank, bool full_nvlink) { + int world_size = fake_ipc_ptrs.size(); + if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); + + sglang::Signal* ipc_ptrs[8]; + + for (int i = 0; i < world_size; i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + + return (fptr_t) new sglang::CustomAllreduce( + ipc_ptrs, rank_data, rank_data_sz, rank, world_size, full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. 
A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. A[:, 1:, 1:]: Not OK + */ +// bool _is_weak_contiguous(torch::Tensor& t) { +// return t.is_contiguous() || +// (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); +// } + +/** + * Performs an out-of-place allreduce and stores result in out. + * + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. + */ +extern "C" void all_reduce(fptr_t _fa, void* inp, void* out, size_t rank_data_sz, mcclDataType_t datatype, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes, musaStream_t stream) { + auto fa = reinterpret_cast(_fa); + // const at::musa::OptionalMUSAGuard device_guard(device_of(inp)); + // auto stream = c10::musa::getCurrentMUSAStream().stream(); + + // TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + // TORCH_CHECK_EQ(inp.numel(), out.numel()); + // TORCH_CHECK(_is_weak_contiguous(out)); + // TORCH_CHECK(_is_weak_contiguous(inp)); + // auto input_size = inp.numel() * inp.element_size(); + size_t rank_data_element_sz = 0; + switch (datatype) { + case mcclFloat: + rank_data_element_sz = 4; + break; + case mcclHalf: + rank_data_element_sz = 2; + break; + case mcclBfloat16: + rank_data_element_sz = 2; + break; + default: + throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16"); + } + + auto input_size = rank_data_sz * rank_data_element_sz; + auto reg_buffer = reinterpret_cast(_reg_buffer); + + if (reg_buffer) { + // TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); + CHECK_MUSA_SUCCESS(musaMemcpyAsync(reg_buffer, inp, input_size, musaMemcpyDeviceToDevice, stream)); + } else { + reg_buffer = inp; + } + + switch (datatype) { + case mcclFloat: { + fa->allreduce( + stream, reinterpret_cast(reg_buffer), reinterpret_cast(out), (int)rank_data_sz); + break; + } + 
case mcclHalf: { + fa->allreduce( + stream, reinterpret_cast(reg_buffer), reinterpret_cast(out), (int)rank_data_sz); + break; + } + case mcclBfloat16: { + fa->allreduce<__mt_bfloat16>( + stream, reinterpret_cast<__mt_bfloat16*>(reg_buffer), reinterpret_cast<__mt_bfloat16*>(out), (int)rank_data_sz); + break; + } + default: + throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void dispose(fptr_t _fa) { + delete reinterpret_cast(_fa); +} + +extern "C" int64_t meta_size() { + return sizeof(sglang::Signal); +} + +extern "C" void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { + auto fa = reinterpret_cast(_fa); + // TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); + void* ipc_ptrs[8]; + for (int i = 0; i < fake_ipc_ptrs.size(); i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + fa->register_buffer(ipc_ptrs); +} + +// Use vector to represent byte data for python binding compatibility. +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); + std::vector bytes(handle.begin(), handle.end()); + return std::make_tuple(bytes, offsets); +} + +// Use vector to represent byte data for python binding compatibility. 
+void register_graph_buffers( + fptr_t _fa, const std::vector>& handles, const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + std::vector bytes; + bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + bytes.reserve(handles.size()); + fa->register_graph_buffers(bytes, offsets); +} diff --git a/src/infiniccl/moore/custom_all_reduce.muh b/src/infiniccl/moore/custom_all_reduce.muh new file mode 100644 index 000000000..b9d1e1f95 --- /dev/null +++ b/src/infiniccl/moore/custom_all_reduce.muh @@ -0,0 +1,711 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cuh +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// #include "utils.h" + +#define CHECK_MUSA_SUCCESS(cmd) \ + do { \ + musaError_t e = cmd; \ + if (e != musaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + musaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace sglang { + +#ifdef USE_MUSA +constexpr int kMaxBlocks = 60; +constexpr int kDefaultThreads = 1024; +constexpr int kDefaultBlockLimit = 60; +constexpr int kMaxThreadsPerBlock = 1024; +#else +constexpr int kMaxBlocks = 36; +constexpr int kDefaultThreads = 512; +constexpr int kDefaultBlockLimit = 36; +constexpr int kMaxThreadsPerBlock = 512; +#endif +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; +struct Signal { + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. 
We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; +}; + +struct __align__(16) RankData { +#ifdef USE_MUSA + const void* ptrs[8]; +#else + const void* __restrict__ ptrs[8]; +#endif +}; + +struct __align__(16) RankSignals { + Signal* signals[8]; +}; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { + return __half2float(val); +} + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) { + return a += b; +} + +#if (__MUSA_ARCH__ >= 800 || !defined(__MUSA_ARCH__)) || (__MUSA_ARCH__ >= 220 || !defined(__MUSA_ARCH__)) +DINLINE float upcast_s(__mt_bfloat16 val) { + return __bfloat162float(val); +} +template <> +DINLINE __mt_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE __mt_bfloat16& assign_add(__mt_bfloat16& a, __mt_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } 
else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE void _downcast(T *out, float* val) { +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + out[i] = downcast_s(val[i]); + } +} + +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#ifdef USE_MUSA + volatile_store((uint32_t)flag, (uint32_t*)flag_addr); +#elif defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { +#ifdef USE_MUSA + flushInv_byp(); + return (uint32_t)volatile_load((uint32_t*)flag_addr); +#endif + + FlagType flag; +#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" : "=r"(flag) : "l"(flag_addr)); +#endif + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { +#ifdef USE_MUSA + volatile FlagType* volatile_ptr = (volatile FlagType*)flag_addr; + *volatile_ptr = flag; +#else + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { +#ifdef USE_MUSA + volatile FlagType* volatile_ptr = (volatile FlagType*)flag_addr; + return *volatile_ptr; +#endif + + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very 
first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, int rank) { + if constexpr (!is_start) +#ifdef USE_MUSA + __syncthreads_lm(); +#else + __syncthreads(); +#endif + static_assert(!(is_start && need_fence)); // Start barrier shouldn't need fence. + if (threadIdx.x < ngpus) { + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val) + ; + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val) + ; + } + } + if constexpr (is_start || need_fence) +#ifdef USE_MUSA + __syncthreads_lm(); +#else + __syncthreads(); +#endif +} + +template +DINLINE void multi_gpu_barrier_with_atomic(const RankSignals& sg, Signal* self_sg, int32_t rank) { + if (threadIdx.x < ngpus) { + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + auto peer_counter_ptr = &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + atomicExch(peer_counter_ptr, val); + while (atomicAdd(self_counter_ptr, 0) != val) { + } + } + __syncthreads_lm(); +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) 
{ + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast<P>
(tmp); +} + +template +__global__ void __launch_bounds__(kMaxThreadsPerBlock, 1) cross_device_reduce_1stage( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + multi_gpu_barrier(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + multi_gpu_barrier(sg, self_sg, rank); +} + +template +DINLINE P* get_tmp_buf(Signal* sg) { + return (P*)(((Signal*)sg) + 1); +} + +template +DINLINE void shfl_reduce(float *res) { + if constexpr (nranks >= 4) { +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + res[i] += __shfl_xor_sync(0xffffffff, res[i], 16); + } + } +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + res[i] += __shfl_xor_sync(0xffffffff, res[i], 8); + } +} + +template +__global__ void __launch_bounds__(kMaxThreadsPerBlock, 1) custom_all_reduce_2shot( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int32_t local_rank, int32_t size, FlagType round) { + constexpr int32_t nranks_sft = (nranks >> 1) - (nranks >> 3); // 8->3, 4->2, 2->1 + constexpr int32_t coalesce_num = 8; + constexpr int32_t coalesce_sft = 3; // 8 threads per rank in group + constexpr int32_t group_size = nranks << coalesce_sft; // 64 threads per group when 8 ranks + constexpr int32_t group_stride_sft = nranks_sft + coalesce_sft; + const int32_t tidx = threadIdx.x; + const int32_t bidx = blockIdx.x; + const int32_t thread_num = blockDim.x; + const int32_t lane_idx = tidx & 31; + const int32_t warp_idx = tidx >> 5; + const int32_t group_num = thread_num >> group_stride_sft; + const int32_t target_rank = (tidx >> coalesce_sft) & (nranks - 1); + 
const int32_t group_id = tidx >> group_stride_sft; + const int32_t coalesce_tid = tidx & (coalesce_num - 1); + + typedef int16_t Vec __attribute__((vector_size(16))); + + const int32_t stride = gridDim.x * thread_num; + // coalesce_id + local_rank * coalesce_num + group_id * nranks * coalesce_num + // + bidx * group_num * nranks * coalesce_num + int32_t idx_base = bidx * thread_num; + int32_t idx_in_blk = coalesce_tid + (local_rank << coalesce_sft) + (group_id << group_stride_sft); + + + // first sync barrier + FlagType *target_barrier = nullptr; + FlagType *local_barrier = nullptr; + FlagType flag; + if (tidx < nranks) { + flag = round << 1; + target_barrier = &sg.signals[tidx]->peer_counter[flag & 1][bidx][local_rank]; + local_barrier = &self_sg->peer_counter[flag & 1][bidx][tidx]; + atomicExch(target_barrier, flag); + while (atomicAdd(local_barrier, 0) != flag) {} + } + __syncthreads_lm(); + + // reduce scatter + Vec* target_ptr = (Vec*)_dp->ptrs[target_rank]; + Vec* buffer_ptr = get_tmp_buf(sg.signals[local_rank]); + do { + int32_t idx = idx_in_blk + idx_base; + float temp_res[vlen] = {0}; + if (idx < size){ + T* data = reinterpret_cast(&(target_ptr[idx])); +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + temp_res[i] = upcast_s(data[i]); + } + } + shfl_reduce(temp_res); + // reduce cross warp + if constexpr (nranks == 8) { + __shared__ float smem[kMaxThreadsPerBlock << 1]; + if (lane_idx < coalesce_num) { +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + smem[warp_idx * vlen * coalesce_num + coalesce_tid * vlen + i] = temp_res[i]; + } + } + __syncthreads_lm(); +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + temp_res[i] += smem[(warp_idx ^ 1) * vlen * coalesce_num + coalesce_tid * vlen + i]; + } + } + + if (local_rank == target_rank && idx < size) { + Vec res; +#pragma unroll + for (int32_t i = 0; i < vlen; i++) { + reinterpret_cast(&res)[i] = downcast_s(temp_res[i]); + } + buffer_ptr[idx] = res; + } + idx_base += stride; + } 
while(idx_base < size); + // make sure buffer_ptr data ready + __musa_barrier_slc(); + __syncthreads_lm(); + if (tidx == 0) { + __threadfence_system_noflush(); + } + buffer_ptr = get_tmp_buf<Vec>(sg.signals[target_rank]); + // second sync barrier + if (tidx < nranks) { + flag++; + target_barrier = &sg.signals[tidx]->peer_counter[flag & 1][bidx][local_rank]; + local_barrier = &self_sg->peer_counter[flag & 1][bidx][tidx]; + atomicExch(target_barrier, flag); + while (atomicAdd(local_barrier, 0) != flag) {} + } + __syncthreads_lm(); + + // all gather + idx_in_blk = coalesce_tid + (target_rank << coalesce_sft) + (group_id << group_stride_sft); + idx_base = bidx * thread_num; + do { + int32_t idx = idx_in_blk + idx_base; + if (idx < size) { + reinterpret_cast<Vec*>(result)[idx] = buffer_ptr[idx]; + } + idx_base += stride; + } while (idx_base < size); +} + +template <typename T, int ngpus> +__global__ void __launch_bounds__(kMaxThreadsPerBlock, 1) cross_device_reduce_2stage( + RankData* _dp, RankSignals sg, Signal* self_sg, T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t<T>::P; + using A = typename packed_t<T>::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf<P>
(sg.signals[target]); + } + auto tmp_out = tmps[0]; + multi_gpu_barrier(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + multi_gpu_barrier(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(musaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(musaIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. + std::unordered_map buffers_; + Signal* self_sg_; + FlagType round; + + // Stores rank data from all ranks. This is mainly for musa graph purposes. + // For musa graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. 
Each rank obtains the IPC handles for each addresses used during musa + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allreduce synchronization, and the second section + * is for storing the intermediate results required by some allreduce algos. + * + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. + */ + CustomAllreduce( + Signal** signals, void* rank_data, size_t rank_data_sz, int rank, int world_size, bool full_nvlink = true) + : rank_(rank), + world_size_(world_size), + full_nvlink_(full_nvlink), + self_sg_(signals[rank]), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)), + round(0) { + for (int i = 0; i < world_size_; i++) { + sg_.signals[i] = signals[i]; + } + + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CHECK_MUSA_SUCCESS(musaIpcOpenMemHandle( + (void**)&ipc_ptr, *((const musaIpcMemHandle_t*)ipc_handle), musaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(musaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < 
num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (muPointerGetAttribute(&base_ptr, MU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (MUdeviceptr)ptr) != MUSA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CHECK_MUSA_SUCCESS(musaIpcGetMemHandle((musaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + data.ptrs[i] = ptrs[i]; + } + auto d_data = d_rank_data_base_++; + CHECK_MUSA_SUCCESS(musaMemcpy(d_data, &data, sizeof(RankData), musaMemcpyHostToDevice)); + buffers_[ptrs[rank_]] = d_data; + } + + // Note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. 
+ void + register_graph_buffers(const std::vector& handles, const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = open_ipc_handle(&handles[j][i * sizeof(musaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CHECK_MUSA_SUCCESS( + musaMemcpy(d_rank_data_base_, rank_data.data(), sizeof(RankData) * num_buffers, musaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * Performs allreduce, assuming input has already been registered. + * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. + */ + template + void allreduce(musaStream_t stream, T* input, T* output, int size, + int threads = kDefaultThreads, int block_limit = kDefaultBlockLimit) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error( + "max supported block limit is " + std::to_string(kMaxBlocks) + ". 
Got " + std::to_string(block_limit)); + + RankData* ptrs; + musaStreamCaptureStatus status; + CHECK_MUSA_SUCCESS(musaStreamIsCapturing(stream, &status)); + if (status == musaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + std::to_string(reinterpret_cast(input)) + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); + +#define KL(ngpus, name) name<<>>(ptrs, sg_, self_sg_, output, rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + case 4: + if constexpr (!std::is_same::value) { + custom_all_reduce_2shot<<>>(ptrs, + sg_, self_sg_, output, rank_, size, ++round); + } else { + if (bytes < 256 * 1024) { + KL(4, cross_device_reduce_1stage); + } else { + KL(4, cross_device_reduce_2stage); + } + } + break; + REDUCE_CASE(6) + case 8: + if constexpr (!std::is_same::value) { + custom_all_reduce_2shot<<>>(ptrs, + sg_, self_sg_, output, rank_, size, ++round); + } else { + if (bytes < 256 * 1024) { + KL(8, cross_device_reduce_1stage); + } else { + KL(8, cross_device_reduce_2stage); + } + } + break; + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). 
Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CHECK_MUSA_SUCCESS(musaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void sglang::CustomAllreduce::allreduce(musaStream_t, half *, + half *, int, int, int); +*/ +} // namespace sglang diff --git a/src/infiniccl/moore/infiniccl_moore.cc b/src/infiniccl/moore/infiniccl_moore.cc index 4dee4aebd..83c809723 100644 --- a/src/infiniccl/moore/infiniccl_moore.cc +++ b/src/infiniccl/moore/infiniccl_moore.cc @@ -1,4 +1,5 @@ #include "infiniccl_moore.h" +#include "custom_all_reduce.h" #include "../../utils.h" @@ -7,6 +8,9 @@ #include #include +#include +#include +#include #define CHECK_MCCL(API__) CHECK_INTERNAL(API__, mcclSuccess) @@ -23,8 +27,12 @@ inline mcclDataType_t getMcclDtype(infiniDtype_t datatype) { return mcclFloat; case INFINI_DTYPE_F16: return mcclHalf; + +#if MARCH_TYPE == 310 case INFINI_DTYPE_BF16: return mcclBfloat16; +#endif + default: std::abort(); return mcclHalf; @@ -50,7 +58,8 @@ inline mcclRedOp_t getMcclRedOp(infinicclReduceOp_t op) { } inline mcclComm_t getMcclComm(infinicclComm_t comm) { - return static_cast(comm->comm); + CustomAllReduceComm* customComm = static_cast(comm->comm); + return static_cast(customComm->acomm); } namespace infiniccl::moore { @@ -63,15 +72,26 @@ infiniStatus_t commInitAll( std::vector mccl_comms(ndevice); CHECK_MCCL(mcclCommInitAll(mccl_comms.data(), ndevice, (int const *)device_ids)); + std::vector> futures; + futures.reserve(ndevice); + for (int i = 0; i < ndevice; i++) { - comms[i] = new InfinicclComm{INFINI_DEVICE_MOORE, device_ids[i], (void *)(mccl_comms[i])}; + futures.emplace_back(std::async(std::launch::async, [i, ndevice, &mccl_comms]() { + return new CustomAllReduceComm(i, ndevice, mccl_comms[i]); + })); } + for (int i = 0; i < ndevice; 
i++) { + auto ca = futures[i].get(); + comms[i] = new InfinicclComm{INFINI_DEVICE_MOORE, device_ids[i], (void *)(ca)}; + } return INFINI_STATUS_SUCCESS; } infiniStatus_t commDestroy(infinicclComm_t comm) { CHECK_MCCL(mcclCommDestroy(getMcclComm(comm))); + // CustomAllReduceComm* customComm = static_cast(comm->comm); + // delete customComm; delete comm; return INFINI_STATUS_SUCCESS; } @@ -85,11 +105,160 @@ infiniStatus_t allReduce( infinicclComm_t comm, infinirtStream_t stream) { - CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16); + if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16 && datatype != INFINI_DTYPE_BF16) { + return INFINI_STATUS_BAD_PARAM; + } - CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype), + CustomAllReduceComm* customComm = static_cast(comm->comm); + if (customComm->should_custom_ar(sendbuf, count, datatype) == false) { + // std::cout << "using mccl" << std::endl; + // std::cout << "using mccl count: " << count << ", datatype: " << getMcclDtype(datatype) << std::endl; + CHECK_MCCL(mcclAllReduce(sendbuf, recvbuf, count, getMcclDtype(datatype), getMcclRedOp(op), getMcclComm(comm), getMusaStream(stream))); + // std::cout << "end cmccl" << std::endl; + } else { + // std::cout << "using custom all reduce" << std::endl; + // std::cout << "using mccl count: " << count << ", datatype: " << getMcclDtype(datatype) << std::endl; + auto rank = customComm->crank; + all_reduce(customComm->custom_ptr, sendbuf, recvbuf, count, getMcclDtype(datatype), customComm->buffer_ptrs[rank], customComm->max_size, getMusaStream(stream)); + musaError_t e = musaStreamSynchronize(getMusaStream(stream)); + if (e != musaSuccess) { + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, + musaGetErrorString(e)); + exit(EXIT_FAILURE); + } + // std::cout << "end custom all reduce" << std::endl; + } return INFINI_STATUS_SUCCESS; } } // namespace infiniccl::moore + +#define CHECK_MUSA_SUCCESS(cmd) \ + do { \ + 
musaError_t e = cmd; \ + if (e != musaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + musaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + + + +std::vector create_shared_buffer(int64_t size_in_bytes, int ndev, int rank, void* comm) { + musaStream_t stream; + CHECK_MUSA_SUCCESS(musaStreamCreateWithFlags(&stream, musaStreamNonBlocking)); + + void* pointer = nullptr; + CHECK_MUSA_SUCCESS(musaMalloc(&pointer, size_in_bytes)); + CHECK_MUSA_SUCCESS(musaMemset(pointer, 0, size_in_bytes)); + + musaIpcMemHandle_t handle; + CHECK_MUSA_SUCCESS(musaIpcGetMemHandle(&handle, pointer)); + + size_t handle_size = sizeof(musaIpcMemHandle_t); + + void* input_tensor = nullptr; + void* recv_buffer = nullptr; + CHECK_MUSA_SUCCESS(musaMalloc(&input_tensor, handle_size)); + CHECK_MUSA_SUCCESS(musaMalloc(&recv_buffer, handle_size * ndev)); + CHECK_MUSA_SUCCESS(musaMemcpyAsync(input_tensor, &handle, handle_size, musaMemcpyHostToDevice, stream)); + + mcclResult_t e = mcclAllGather(input_tensor, recv_buffer, handle_size, mcclUint8, static_cast(comm), stream); + if (e != mcclSuccess) { + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, + mcclGetErrorString(e)); + exit(EXIT_FAILURE); + } + + CHECK_MUSA_SUCCESS(musaStreamSynchronize(stream)); + + musaIpcMemHandle_t* handles = new musaIpcMemHandle_t[ndev]; + musaMemcpy(handles, recv_buffer, handle_size * ndev, musaMemcpyDeviceToHost); + + std::vector pointers; + for (int i = 0; i < ndev; ++i) { + if (i == rank) { + pointers.push_back(pointer); + } else { + void* remote_ptr = nullptr; + CHECK_MUSA_SUCCESS(musaIpcOpenMemHandle(&remote_ptr, handles[i], musaIpcMemLazyEnablePeerAccess)); + pointers.push_back(remote_ptr); + } + } + + std::vector int_pointers; + + int_pointers.resize(pointers.size()); + std::transform(pointers.begin(), pointers.end(), int_pointers.begin(), + [](void* ptr) -> int64_t { + return reinterpret_cast(ptr); + } + ); + + musaStreamDestroy(stream); + // 
musaFree(pointer); + CHECK_MUSA_SUCCESS(musaFree(input_tensor)); // scratch buffer for the handle exchange; no longer referenced after the synchronised copy-back + CHECK_MUSA_SUCCESS(musaFree(recv_buffer)); // same: only `pointer` must outlive this function + return int_pointers; +} + +void free_shared_buffer(std::vector<int64_t> pointers, int rank) { + void* pointer = reinterpret_cast<void*>(pointers[rank]); + musaFree(pointer); +} + +bool CustomAllReduceComm::should_custom_ar(void* inp, size_t count, infiniDtype_t datatype) { + if (!use_custom_all_reduce) { + return false; + } + + if (count % 16 != 0) { + return false; + } + + return count * (datatype == INFINI_DTYPE_F32 ? 4 : 2) < max_size; // compare the payload in bytes: F32 is 4 bytes, F16/BF16 are 2; the old `count * 2` let F32 payloads overrun the staging buffer +} + +CustomAllReduceComm::CustomAllReduceComm(int64_t rank, int ndev, void* comm) { + CHECK_MUSA_SUCCESS(musaSetDevice(rank)); + + devices = ndev; + crank = rank; + acomm = comm; + auto it = std::find( + support_world_sizes.begin(), + support_world_sizes.end(), + devices + ); + + if (it == support_world_sizes.end()) { + use_custom_all_reduce = false; + return; + } + + int64_t metasize = meta_size(); + meta_ptrs = create_shared_buffer( + metasize + max_size, devices, crank, acomm + ); + + size_t num_elements = 8 * 1024 * 1024; + size_t element_size = sizeof(uint8_t); + size_t total_bytes = num_elements * element_size; + rank_data = nullptr; // assign the member: a local of the same name used to shadow it and left the member uninitialised + CHECK_MUSA_SUCCESS(musaMalloc(&rank_data, total_bytes)); + buffer_ptrs = create_shared_buffer(max_size, devices, crank, acomm); + + custom_ptr = init_custom_ar(meta_ptrs, rank_data, num_elements, crank, true); + register_buffer(custom_ptr, buffer_ptrs); +} + +// CustomAllReduceComm::~CustomAllReduceComm() { +// musaFree(rank_data); +// for (int i = 0; i < devices; ++i) { +// free_shared_buffer(meta_ptrs, i); +// free_shared_buffer(buffer_ptrs, i); +// } +// } + diff --git a/src/infiniccl/moore/infiniccl_moore.h b/src/infiniccl/moore/infiniccl_moore.h index 318fc468b..54e174dce 100644 --- a/src/infiniccl/moore/infiniccl_moore.h +++ b/src/infiniccl/moore/infiniccl_moore.h @@ -2,6 +2,8 @@ #define INFINICCL_MOORE_H_ #include "../infiniccl_impl.h" +#include +#include #if defined(ENABLE_MOORE_API) && defined(ENABLE_CCL) INFINICCL_DEVICE_API_IMPL(moore) @@ -10,3 
+12,21 @@ INFINICCL_DEVICE_API_NOOP(moore) #endif #endif /* INFINICCL_MOORE_H_ */ + +struct CustomAllReduceComm { + const size_t max_size = 128 * 1024 * 1024; + const std::vector support_world_sizes = {2,8}; + + void *acomm; + int devices; + int64_t crank; + bool use_custom_all_reduce = true; + std::vector meta_ptrs; + std::vector buffer_ptrs; + void* rank_data; + int64_t custom_ptr; + + CustomAllReduceComm(int64_t rank, int ndev, void* group); + ~CustomAllReduceComm(); + bool should_custom_ar(void* inp, size_t count, infiniDtype_t datatype); +}; \ No newline at end of file diff --git a/src/infiniccl/moore/utils.h b/src/infiniccl/moore/utils.h new file mode 100644 index 000000000..8da9b4012 --- /dev/null +++ b/src/infiniccl/moore/utils.h @@ -0,0 +1,481 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#ifdef USE_MUSA +#include +#define CU_POINTER_ATTRIBUTE_RANGE_START_ADDR MU_POINTER_ATTRIBUTE_RANGE_START_ADDR +#define CUdeviceptr MUdeviceptr +#define CUDA_SUCCESS MUSA_SUCCESS +#define cuPointerGetAttribute muPointerGetAttribute +#define cudaDevAttrComputeCapabilityMajor musaDevAttrComputeCapabilityMajor +#define cudaDevAttrComputeCapabilityMinor musaDevAttrComputeCapabilityMinor +#define cudaDeviceGetAttribute musaDeviceGetAttribute +#define cudaDeviceProp musaDeviceProp +#define cudaError_t musaError_t +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaSuccess musaSuccess +using __nv_bfloat16 = __mt_bfloat16; +using __nv_bfloat162 = __mt_bfloat162; +using nv_bfloat16 = __mt_bfloat16; +using nv_bfloat162 = __mt_bfloat162; +using nv_half = __half; + +#endif +#include + +#ifdef USE_ROCM +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = __half; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_BF16(c_type, ...) \ + case at::ScalarType::BFloat16: { \ + using c_type = __hip_bfloat16; \ + return __VA_ARGS__(); \ + } +#endif // USE_ROCM + +#ifndef USE_ROCM +// Adapt from FlashInfer +#ifdef FLASHINFER_ENABLE_F16 +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_F16(c_type, ...) +#endif // FLASHINFER_ENABLE_F16 + +#ifdef FLASHINFER_ENABLE_BF16 +#define _DISPATCH_CASE_BF16(c_type, ...) \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_BF16(c_type, ...) 
+#endif // FLASHINFER_ENABLE_BF16 + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) +#endif // FLASHINFER_ENABLE_FP8_E4M3 + +#ifdef FLASHINFER_ENABLE_FP8_E5M2 +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) +#endif // FLASHINFER_ENABLE_FP8_E5M2 + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) 
\ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \ + << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK( \ + num_qo_heads % num_kv_heads == 0, \ + "num_qo_heads(", \ + num_qo_heads, \ + ") must be divisible by num_kv_heads(", \ + num_kv_heads, \ + ")") + +#ifdef USE_MUSA +#define CHECK_CUDA(x) TORCH_CHECK(true, #x " must be a CUDA tensor") +#else +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#endif + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) + +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} +#endif // USE_ROCM + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. 
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline bool isDeviceType(const std::string& device_type) { + int deviceCount; + CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount)); + + int device_id = -1; + if (deviceCount >= 1) { + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + } else { + return false; + } + + cudaDeviceProp prop; + CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id)); + if (device_type == std::string(prop.name)) { + return true; + } + return false; +} + +inline bool getBoolEnv(char const* name) { + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +inline bool getEnvEnablePDL() { + static std::once_flag flag; + static bool enablePDL = false; + std::call_once(flag, [&]() { + if (getSMVersion() >= 90) { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + }); + return enablePDL; +} + +// 
SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28 +#ifndef USE_ROCM +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width)) +#else +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width)) +#endif + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +#if !defined(USE_ROCM) && !defined(USE_MUSA) +# define WARP_SIZE 32 +#elif defined(USE_ROCM) +# if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +# define WARP_SIZE 64 +# else +# define WARP_SIZE 32 +#endif +#else +# if defined(__MUSA_ARCH__) && __MUSA_ARCH__ <= 220 +# define WARP_SIZE 128 +# else +# define WARP_SIZE 32 +# endif +#endif + +#ifdef USE_ROCM + +#include "hip/hip_math_def.h" +#include "hip/hip_vec_dtypes.h" + +#else + +template +__device__ __forceinline__ float castToFloat(srcDtype val) { + return static_cast(val); +} + +template +__device__ __forceinline__ dstDtype castFromFloat(float val) { + return static_cast(val); +} + +#endif + +// add FP8 support +#ifndef USE_ROCM +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else // USE_ROCM +#if HIP_FP8_TYPE_FNUZ +#include +using FP8_TYPE = c10::Float8_e4m3fnuz; +constexpr auto FP8_E4M3_MAX = 224.0f; +#else +#if HIP_FP8_TYPE_E4M3 +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else +#error "fp8 is not supported in this processor (arch < gfx942)." +#endif // HIP_FP8_TYPE_E4M3 +#endif // HIP_FP8_TYPE_FNUZ +#endif // USE_ROCM + +#define FULL_MASK 0xffffffff + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? 
__int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +__device__ __forceinline__ float warpReduceMax(float value) { + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1)); + return value; +} + +__device__ __forceinline__ float blockReduceMax(float value) { + static __shared__ float warpLevelMaxs[WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + + value = warpReduceMax(value); + + if (laneId == 0) warpLevelMaxs[warpId] = value; + __syncthreads(); + + value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; + if (warpId == 0) value = warpReduceMax(value); + + return value; +} + +// Pads to a multiple of `alignment` rows. 
+inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { + int64_t rows = tensor.size(0); + int64_t cols = tensor.size(1); + int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size + + if (pad_rows == 0) { + return tensor; // Already aligned + } + + torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); + torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows + + // Ensure column-major layout + if (is_column_major) { + return tensor_padded.t().contiguous().t(); + } + return tensor_padded; +} + +// Get the next power of 2 of a number +inline uint32_t next_pow2(uint32_t x) noexcept { + if (x <= 1) return 1; + return 1u << (32 - __builtin_clz(x - 1)); +} diff --git a/xmake/moore.lua b/xmake/moore.lua index 908a332cc..7c5bcdbbf 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -16,7 +16,7 @@ rule("mu") local mcc = MUSA_ROOT .. "/bin/mcc" local includedirs = table.concat(target:get("includedirs"), " ") - local args = {"--cuda-gpu-arch=mp_31", "-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} + local args = {"--cuda-gpu-arch=mp_31", "-DUSE_MUSA", "-c", sourcefile, "-o", objectfile, "-I" .. MUSA_ROOT .. "/include", "-O3", "-fPIC", "-Wall", "-std=c++17", "-pthread"} for _, includedir in ipairs(target:get("includedirs")) do table.insert(args, "-I" .. 
includedir) end @@ -68,16 +68,22 @@ target("infiniccl-moore") set_kind("static") add_deps("infinirt") on_install(function (target) end) + set_languages("cxx17") set_warnings("all", "error") if not is_plat("windows") then - add_cxflags("-fPIC") + add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") add_cxxflags("-fPIC") end if has_config("ccl") then add_links("libmccl.so") add_files("../src/infiniccl/moore/*.cc") - add_defines("MARCH_TYPE=310") - add_cxxflags("-Wno-unused-function") + add_files("../src/infiniccl/moore/*.mu", {rule = "mu"}) + + -- Moore GPU arch with mp_31 support mcclBfloat16 in MCCL + if get_config("moore-gpu-arch") == "mp_31" then + add_defines("MARCH_TYPE=310") + add_cxxflags("-Wno-unused-function") + end end set_languages("cxx17")