164 changes: 98 additions & 66 deletions python/sglang/srt/_custom_ops.py
@@ -4,29 +4,42 @@
 
 import torch
 
-from sglang.srt.utils import is_hip, is_hpu, is_npu
+from sglang.srt.utils import is_cuda, is_hip
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
 _is_hip = is_hip()
 
-if not is_hpu():
-    try:
-        import sgl_kernel
-    except ImportError as e:
+IS_CUSTOM_AR_AVAILABLE = _is_cuda or _is_hip
+IS_QUICK_AR_AVAILABLE = _is_hip
+# TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
+IS_MSCCLPP_AR_AVAILABLE = _is_cuda
+
+try:
+    import sgl_kernel.allreduce as _custom_ar
+except ImportError as e:
+    if _is_cuda or _is_hip:
         logger.warning("Failed to import from custom_ar with %r", e)
+    IS_CUSTOM_AR_AVAILABLE = False
+    IS_QUICK_AR_AVAILABLE = False
+    IS_MSCCLPP_AR_AVAILABLE = False
 
-if not is_hip() and not is_npu():
-    custom_op = sgl_kernel.allreduce
+# region IS_CUSTOM_AR_AVAILABLE
+
+if not IS_CUSTOM_AR_AVAILABLE:
+    pass
+
+elif _is_cuda:
+    # CUDA custom allreduce
 
-    # custom allreduce
     def init_custom_ar(
         ipc_tensors: List[torch.Tensor],
         rank_data: torch.Tensor,
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return custom_op.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
+        return _custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, full_nvlink)
 
     def all_reduce(
         fa: int,
@@ -35,26 +48,26 @@ def all_reduce(
         reg_buffer: int,
         reg_buffer_sz_bytes: int,
     ) -> None:
-        custom_op.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+        _custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
 
     def dispose(fa: int) -> None:
-        custom_op.dispose(fa)
+        _custom_ar.dispose(fa)
 
     def meta_size() -> int:
-        return custom_op.meta_size()
+        return _custom_ar.meta_size()
 
     def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-        return custom_op.register_buffer(fa, ipc_tensors)
+        return _custom_ar.register_buffer(fa, ipc_tensors)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-        return custom_op.get_graph_buffer_ipc_meta(fa)
+        return _custom_ar.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[List[int]], offsets: List[List[int]]
     ) -> None:
-        custom_op.register_graph_buffers(fa, handles, offsets)
+        _custom_ar.register_graph_buffers(fa, handles, offsets)
 
-else:
+elif _is_hip:
     # ROCM custom allreduce
 
     def init_custom_ar(
@@ -65,55 +78,64 @@ def init_custom_ar(
         rank: int,
         full_nvlink: bool,
     ) -> int:
-        return sgl_kernel.allreduce.init_custom_ar(
+        return _custom_ar.init_custom_ar(
             meta, rank_data, handles, offsets, rank, full_nvlink
         )
 
     def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        sgl_kernel.allreduce.all_reduce_reg(fa, inp, out)
+        _custom_ar.all_reduce_reg(fa, inp, out)
 
     def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
-        sgl_kernel.allreduce.all_reduce_unreg(fa, inp, reg_buffer, out)
+        _custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)
 
     def dispose(fa: int) -> None:
-        sgl_kernel.allreduce.dispose(fa)
+        _custom_ar.dispose(fa)
 
     def meta_size() -> int:
-        return sgl_kernel.allreduce.meta_size()
+        return _custom_ar.meta_size()
 
     def register_buffer(
         fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
     ) -> None:
-        return sgl_kernel.allreduce.register_buffer(fa, t, handles, offsets)
+        return _custom_ar.register_buffer(fa, t, handles, offsets)
 
     def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return sgl_kernel.allreduce.get_graph_buffer_ipc_meta(fa)
+        return _custom_ar.get_graph_buffer_ipc_meta(fa)
 
     def register_graph_buffers(
         fa: int, handles: List[str], offsets: List[List[int]]
     ) -> None:
-        sgl_kernel.allreduce.register_graph_buffers(fa, handles, offsets)
+        _custom_ar.register_graph_buffers(fa, handles, offsets)
 
     def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return sgl_kernel.allreduce.allocate_meta_buffer(size)
+        return _custom_ar.allocate_meta_buffer(size)
 
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
+        return _custom_ar.get_meta_buffer_ipc_handle(inp)
 
 
+# endregion
+
+# region IS_QUICK_AR_AVAILABLE
+
+if not IS_QUICK_AR_AVAILABLE:
+    pass
+
+elif _is_hip:
     # ROCM custom quick allreduce
 
     def init_custom_qr(
         rank: int, world_size: int, qr_max_size: Optional[int] = None
     ) -> int:
-        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
+        return _custom_ar.init_custom_qr(world_size, rank, qr_max_size)
 
     def qr_get_handle(fa: int) -> torch.Tensor:
-        return sgl_kernel.allreduce.qr_get_handle(fa)
+        return _custom_ar.qr_get_handle(fa)
 
     def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
-        sgl_kernel.allreduce.qr_open_handles(fa, handles)
+        _custom_ar.qr_open_handles(fa, handles)
 
     def qr_all_reduce(
         fa: int,
@@ -122,44 +144,54 @@ def qr_all_reduce(
         quant_level: int,
         cast_bf2half: bool,
     ) -> None:
-        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
+        _custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
 
     def qr_destroy(fa: int) -> None:
-        sgl_kernel.allreduce.qr_destroy(fa)
+        _custom_ar.qr_destroy(fa)
 
     def qr_max_size() -> int:
-        return sgl_kernel.allreduce.qr_max_size()
-
-
-def mscclpp_generate_unique_id() -> bytes:
-    return sgl_kernel.allreduce.mscclpp_generate_unique_id()
-
-
-def mscclpp_init_context(
-    unique_id: bytes,
-    rank: int,
-    world_size: int,
-    scratch: torch.Tensor,
-    put_buffer: torch.Tensor,
-    nranks_per_node: int,
-    rank_to_node: List[int],
-    rank_to_ib: List[int],
-    context_selection: int,
-) -> int:
-    return sgl_kernel.allreduce.mscclpp_init_context(
-        unique_id,
-        rank,
-        world_size,
-        scratch,
-        put_buffer,
-        nranks_per_node,
-        rank_to_node,
-        rank_to_ib,
-        context_selection,
-    )
-
-
-def mscclpp_allreduce(
-    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
-) -> None:
-    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+        return _custom_ar.qr_max_size()
+
+
+# endregion
+
+# region IS_MSCCLPP_AR_AVAILABLE
+
+if not IS_MSCCLPP_AR_AVAILABLE:
+    pass
+
+elif _is_cuda:
+
+    def mscclpp_generate_unique_id() -> bytes:
+        return _custom_ar.mscclpp_generate_unique_id()
+
+    def mscclpp_init_context(
+        unique_id: bytes,
+        rank: int,
+        world_size: int,
+        scratch: torch.Tensor,
+        put_buffer: torch.Tensor,
+        nranks_per_node: int,
+        rank_to_node: List[int],
+        rank_to_ib: List[int],
+        context_selection: int,
+    ) -> int:
+        return _custom_ar.mscclpp_init_context(
+            unique_id,
+            rank,
+            world_size,
+            scratch,
+            put_buffer,
+            nranks_per_node,
+            rank_to_node,
+            rank_to_ib,
+            context_selection,
+        )
+
+    def mscclpp_allreduce(
+        context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
+    ) -> None:
+        return _custom_ar.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
+
+
+# endregion
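
The net effect of this file's changes: the old scattered `try: import sgl_kernel` probes are replaced by three module-level flags (`IS_CUSTOM_AR_AVAILABLE`, `IS_QUICK_AR_AVAILABLE`, `IS_MSCCLPP_AR_AVAILABLE`) computed from the platform checks and forced to `False` when `sgl_kernel.allreduce` fails to import. A minimal sketch of how a caller might consume them — the backend names and fallback order here are illustrative assumptions, not part of the diff:

```python
# Sketch: gating a collective backend on the availability flags from
# _custom_ops. Only the flag semantics mirror the code in this diff;
# the selection order below is a hypothetical example.
from sglang.srt import _custom_ops as ops


def pick_allreduce_backend() -> str:
    # Flags are computed once at import time and are False whenever
    # sgl_kernel.allreduce could not be imported.
    if ops.IS_MSCCLPP_AR_AVAILABLE:
        return "mscclpp"      # CUDA only, per the TODO in the diff
    if ops.IS_QUICK_AR_AVAILABLE:
        return "quick_ar"     # ROCm only
    if ops.IS_CUSTOM_AR_AVAILABLE:
        return "custom_ar"    # CUDA or ROCm
    return "nccl"             # fall back to the stock backend


print(pick_allreduce_backend())
```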
python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -21,16 +21,6 @@
 from sglang.srt.environ import envs
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, log_info_on_rank0
 
-try:
-    # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-    import sgl_kernel  # noqa: F401
-
-    custom_ar = True
-except ImportError:
-    # For CPUs
-    custom_ar = False
-
-
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
@@ -78,7 +68,7 @@ def __init__(
         self._IS_CAPTURING = False
         self.disabled = True
 
-        if not custom_ar:
+        if not ops.IS_CUSTOM_AR_AVAILABLE:
             # disable because of missing custom allreduce library
             # e.g. in a non-cuda environment
             return
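
This hunk applies the disabled-by-default convention shared by the communicators: `__init__` sets `self.disabled = True` up front and returns early when the availability flag is off, so callers can fall back to NCCL. A condensed sketch of that pattern — `ToyCommunicator` is hypothetical; only the flag check and the `disabled` attribute come from the diff:

```python
# Condensed sketch of the guard pattern this hunk standardizes.
from sglang.srt import _custom_ops as ops


class ToyCommunicator:
    def __init__(self) -> None:
        self.disabled = True  # pessimistic default, as in the diff
        if not ops.IS_CUSTOM_AR_AVAILABLE:
            # e.g. in a non-cuda environment: stay disabled and let
            # the caller use the stock backend instead.
            return
        self.disabled = False  # kernel importable; real setup goes here


comm = ToyCommunicator()
print("custom allreduce enabled:", not comm.disabled)
```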
17 changes: 2 additions & 15 deletions python/sglang/srt/distributed/device_communicators/pymscclpp.py
@@ -11,25 +11,12 @@
 from torch.distributed import ProcessGroup, ReduceOp
 
 from sglang.srt import _custom_ops as ops
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_hip
 
 logger = logging.getLogger(__name__)
 
-_is_cuda = is_cuda()
 _is_hip = is_hip()
 
-mscclpp_is_available = False
-if _is_hip:
-    # TODO(zyksir): mscclpp is untested on AMD and therefore disabled.
-    mscclpp_is_available = False
-if _is_cuda:
-    try:
-        import sgl_kernel  # noqa: F401
-
-        mscclpp_is_available = True
-    except:
-        mscclpp_is_available = False
-
 
 class MscclContextSelection(IntEnum):
     MSCCL1SHOT1NODELL = 1
@@ -127,7 +114,7 @@ def __init__(
         self._IS_CAPTURING = False
         self.disabled = True
 
-        if not mscclpp_is_available:
+        if not ops.IS_MSCCLPP_AR_AVAILABLE:
             # disable because of missing mscclpp library
             # e.g. in a non-cuda environment
             return
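
For reference, the wrappers that back `IS_MSCCLPP_AR_AVAILABLE` follow a generate-id / broadcast / init-context / allreduce flow. A rough single-node sketch under stated assumptions — the scratch and put-buffer sizes, the node/IB rank maps, and the context selection below are placeholder values, not taken from this PR:

```python
# Sketch of the mscclpp wrapper flow exposed by _custom_ops.
import torch
import torch.distributed as dist

from sglang.srt import _custom_ops as ops
from sglang.srt.distributed.device_communicators.pymscclpp import (
    MscclContextSelection,
)


def build_mscclpp_context(rank: int, world_size: int) -> int:
    # Rank 0 creates the unique id, then shares it with the group.
    uid = ops.mscclpp_generate_unique_id() if rank == 0 else None
    obj = [uid]
    dist.broadcast_object_list(obj, src=0)
    scratch = torch.zeros(2 << 20, dtype=torch.uint8, device="cuda")     # assumed size
    put_buffer = torch.zeros(1 << 20, dtype=torch.uint8, device="cuda")  # assumed size
    return ops.mscclpp_init_context(
        obj[0], rank, world_size, scratch, put_buffer,
        nranks_per_node=world_size,       # single-node assumption
        rank_to_node=[0] * world_size,
        rank_to_ib=list(range(world_size)),
        context_selection=int(MscclContextSelection.MSCCL1SHOT1NODELL),
    )
```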
python/sglang/srt/distributed/device_communicators/quick_all_reduce.py
@@ -24,14 +24,6 @@
 _is_hip = is_hip()
 
 
-try:
-    ops.qr_max_size()
-    quick_ar = True
-except Exception:
-    # For CPUs and CUDA
-    quick_ar = False
-
-
 @cache
 def qr_rocm_arch_available():
     if not _is_hip:
@@ -101,7 +93,7 @@ def __init__(
             )
             return
 
-        if not quick_ar:
+        if not ops.IS_QUICK_AR_AVAILABLE:
             # disable because of missing quick reduce library
             # e.g. in a cuda environment
             logger.info(
7 changes: 7 additions & 0 deletions python/sglang/srt/utils/common.py
@@ -104,6 +104,7 @@


 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
+@lru_cache(maxsize=1)
 def is_hip() -> bool:
     return torch.version.hip is not None

@@ -119,18 +120,22 @@ def is_hip() -> bool:
 builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
 
 
+@lru_cache(maxsize=1)
 def is_cuda():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+@lru_cache(maxsize=1)
 def is_cuda_alike():
     return is_cuda() or is_hip()
 
 
+@lru_cache(maxsize=1)
 def is_hpu() -> bool:
     return hasattr(torch, "hpu") and torch.hpu.is_available()
 
 
+@lru_cache(maxsize=1)
 def is_xpu() -> bool:
     return hasattr(torch, "xpu") and torch.xpu.is_available()

@@ -140,6 +145,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
 
+@lru_cache(maxsize=1)
 def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
@@ -149,6 +155,7 @@ def is_host_cpu_x86() -> bool:
     )
 
 
+@lru_cache(maxsize=1)
 def is_cpu() -> bool:
     return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()

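
The `@lru_cache(maxsize=1)` decorators memoize these zero-argument platform probes, so hot paths that call `is_hip()` or `is_cuda()` repeatedly pay the `torch` attribute lookups only once per process. A standalone sketch of the effect (the call counter is illustrative):

```python
# Standalone sketch of what @lru_cache(maxsize=1) buys the platform
# probes: the underlying check runs once, later calls hit the cache.
from functools import lru_cache

calls = 0


@lru_cache(maxsize=1)
def is_hip() -> bool:
    global calls
    calls += 1
    return False  # stand-in for `torch.version.hip is not None`


for _ in range(1000):
    is_hip()

print(calls)                # 1: the probe ran a single time
print(is_hip.cache_info())  # CacheInfo(hits=999, misses=1, maxsize=1, currsize=1)
```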