
Commit 9452b39

update doc and launch params (#733)
Co-authored-by: shihaobai <[email protected]>
1 parent 250d7ad commit 9452b39

File tree (8 files changed: 58 additions, 17 deletions)

  docs/EN/source/getting_started/installation.rst
  docs/EN/source/getting_started/quickstart.rst
  lightllm/common/quantization/__init__.py
  lightllm/distributed/communication_op.py
  lightllm/server/api_cli.py
  lightllm/server/api_start.py
  lightllm/utils/device_utils.py
  requirements.txt

docs/EN/source/getting_started/installation.rst

Lines changed: 4 additions & 12 deletions

@@ -26,7 +26,7 @@ The easiest way to install Lightllm is by using the official image. You can dire
     $
     $ # Run the image
     $ docker run -it --gpus all -p 8080:8080 \
-    $ --shm-size 1g -v your_local_path:/data/ \
+    $ --shm-size 32g -v your_local_path:/data/ \
     $ ghcr.io/modeltc/lightllm:main /bin/bash

 You can also manually build and run the image from the source:
@@ -39,7 +39,7 @@ You can also manually build and run the image from the source:
     $
     $ # Run the image
     $ docker run -it --gpus all -p 8080:8080 \
-    $ --shm-size 1g -v your_local_path:/data/ \
+    $ --shm-size 32g -v your_local_path:/data/ \
     $ <image_name> /bin/bash

 Alternatively, you can use a script to automatically build and run the image:
@@ -81,16 +81,8 @@ NOTE: If you are using torch with cuda 11.x instead, run `pip install nvidia-ncc
 .. note::

     The Lightllm code has been tested on various GPUs, including V100, A100, A800, 4090, and H800.
-    If you are using A100, A800, or similar GPUs, it is recommended to install triton==3.0.0:
+    If you are using A100, A800, or similar GPUs, it is recommended to install triton==3.1.0:

     .. code-block:: console

-        $ pip install triton==3.0.0 --no-deps
-
-    If you are using H800, V100, or similar GPUs, it is recommended to install triton-nightly:
-
-    .. code-block:: console
-
-        $ pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
-
-    For more details, refer to: `issue <https://github.com/triton-lang/triton/issues/3619>`_ and `fix PR <https://github.com/triton-lang/triton/pull/3638>`_
+        $ pip install triton==3.1.0 --no-deps

docs/EN/source/getting_started/quickstart.rst

Lines changed: 10 additions & 0 deletions

@@ -53,6 +53,16 @@ After downloading the Llama-2-7b-chat model, use the following command in the te
 .. note::
     The ``--model_dir`` parameter in the above command should be changed to the actual path of your model on your machine.

+For the DeepSeek-R1 model on H200, it can be launched with the following command:
+
+.. code-block:: console
+
+    $ LOADWORKER=8 python -m lightllm.server.api_server --model_dir ~/models/DeepSeek-R1 --tp 8 --graph_max_batch_size 100
+
+.. note::
+    LOADWORKER specifies the thread for model loading, which can enhance the speed of model loading. The --graph_max_batch_size parameter specifies the number of cudagraphs to be captured, which will capture graphs for batch sizes ranging from 1 to 100.
+
+
 3. (Optional) Test the Model Service
 --------------------------------------

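Once either launch command is up, a small request confirms the service is answering. The sketch below is illustrative: it assumes the default port 8080 and the /generate endpoint used in the quickstart's test step, so adjust the URL and payload to your deployment.

    import requests

    resp = requests.post(
        "http://localhost:8080/generate",
        json={"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}},
        timeout=60,
    )
    print(resp.status_code, resp.json())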

lightllm/common/quantization/__init__.py

Lines changed: 11 additions & 0 deletions

@@ -26,6 +26,17 @@ def _parse_network_config(self, network_config):
         activation_scheme = network_config.get("activation_scheme", "dynamic")
         self.static_activation = activation_scheme == "static"
         self.hf_quantization_config = hf_quantization_config
+        self.hf_quantization_method = hf_quantization_config["quant_method"]
+        self._mapping_quant_method()
+
+    def _mapping_quant_method(self):
+        if self.hf_quantization_method == "fp8":
+            block_size = self.hf_quantization_config.get("weight_block_size", None)
+            if block_size == [128, 128]:
+                self.quant_type = "vllm-fp8w8a8-b128"
+        else:
+            # TODO: more quant method
+            pass

     def _parse_custom_cfg(self, custom_cfg_path):
         self.quant_cfg = collections.defaultdict(dict)
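For context, the new mapping reads the checkpoint's Hugging Face quantization_config. The dict below is a hypothetical example of the shape it expects (values modeled on block-wise fp8 checkpoints such as DeepSeek-R1, not taken from this commit):

    # Hypothetical quantization_config as it would appear in a model's config.json.
    hf_quantization_config = {
        "quant_method": "fp8",            # stored as self.hf_quantization_method
        "activation_scheme": "dynamic",   # "static" would set self.static_activation
        "weight_block_size": [128, 128],  # maps to quant_type "vllm-fp8w8a8-b128"
    }
    # Other quant_method values currently fall into the TODO branch above, and
    # other block sizes leave self.quant_type unchanged.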

lightllm/distributed/communication_op.py

Lines changed: 9 additions & 3 deletions

@@ -24,7 +24,8 @@
 import torch.distributed as dist
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
-from functools import partial
+from lightllm.utils.device_utils import has_nvlink
+from lightllm.utils.envs_utils import get_env_start_args

 original_all_reduce = torch.distributed.all_reduce
 original_all_gather_into_tensor = torch.distributed.all_gather_into_tensor
@@ -67,10 +68,13 @@ def lightllm_capture_graph(self):
             yield

     def set_custom_reduce(self):
-        ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in ["ON", "TRUE", "1"]
+        ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "True").upper() in ["ON", "TRUE", "1"]
         world_size = dist.get_world_size()
         ranks = list(range(world_size))

+        if not has_nvlink() or world_size not in [2, 4, 6, 8]:
+            ENABLE_VLLM_REDUCE = False
+
         # Create a new NCCL group to prevent the original all_reduce from hanging with cudagraph
         if self.device_group is None:
             self.device_group = dist.new_group(ranks, backend="nccl")
@@ -93,11 +97,13 @@ def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_

     def set_custom_gather(self):
         ENABLE_CUSTOM_GATHER = os.getenv("ENABLE_CUSTOM_GATHER", "False").upper() in ["ON", "TRUE", "1"]
+        args = get_env_start_args()
         world_size = dist.get_world_size()
         ranks = list(range(world_size))
         if self.device_group is None:
             self.device_group = dist.new_group(ranks, backend="nccl")
-        if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL:
+
+        if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL or args.disable_custom_allreduce:
            cpu_group = dist.new_group(ranks, backend="gloo")
            self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
            logger.info("Enable Custom ALLGather.")
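To make the new default visible at a glance, the standalone sketch below mirrors the gating in set_custom_reduce (a simplified illustration, not code from this commit): the vLLM custom all-reduce is now on by default, but it is disabled on machines without NVLink or with an unsupported world size.

    import os

    def should_use_vllm_reduce(nvlink_present: bool, world_size: int) -> bool:
        # Mirrors set_custom_reduce above: the env default changed from "False" to "True".
        enabled = os.getenv("ENABLE_VLLM_REDUCE", "True").upper() in ["ON", "TRUE", "1"]
        if not nvlink_present or world_size not in [2, 4, 6, 8]:
            enabled = False
        return enabled

    print(should_use_vllm_reduce(nvlink_present=True, world_size=8))   # True unless overridden
    print(should_use_vllm_reduce(nvlink_present=False, world_size=8))  # False on PCIe-only machines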

lightllm/server/api_cli.py

Lines changed: 2 additions & 1 deletion

@@ -159,6 +159,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional multimodal models."
     )
+    parser.add_argument("--disable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
     parser.add_argument(
         "--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
     )
@@ -225,7 +226,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--graph_max_len_in_batch",
         type=int,
-        default=8192,
+        default=0,
         help="""Maximum sequence length that can be captured by the cuda graph for decodign stage.
         The default value is 8192. It will turn into eagar mode if encounters a larger value. """,
     )

lightllm/server/api_start.py

Lines changed: 3 additions & 0 deletions

@@ -94,6 +94,9 @@ def normal_or_p_d_start(args):
     if not args.enable_chunked_prefill:
         args.chunked_prefill_size = 0

+    if args.graph_max_len_in_batch == 0:
+        args.graph_max_len_in_batch = args.max_req_total_len
+
     # These modes cannot be enabled at the same time.
     assert [
         args.enable_chunked_prefill,

lightllm/utils/device_utils.py

Lines changed: 17 additions & 0 deletions

@@ -1,5 +1,6 @@
 import os
 from functools import lru_cache
+import subprocess


 def set_current_device_id(device_id: int):
@@ -103,3 +104,19 @@ def init_p2p(device_index):
 @lru_cache(maxsize=None)
 def kv_trans_use_p2p():
     return os.getenv("KV_TRANS_USE_P2P", "False").upper() in ["1", "TRUE", "ON"]
+
+
+def has_nvlink():
+    try:
+        # Call nvidia-smi to get the topology matrix
+        result = subprocess.check_output(["nvidia-smi", "topo", "--matrix"])
+        result = result.decode("utf-8")
+
+        # Check if the output contains 'NVLink'
+        if "NVLink" in result:
+            return True
+        else:
+            return False
+    except subprocess.CalledProcessError:
+        # If there's an error (e.g., nvidia-smi is not installed or another issue), assume no NVLink
+        return False
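A minimal usage sketch for the new helper (it assumes nvidia-smi is available on PATH inside the container):

    from lightllm.utils.device_utils import has_nvlink

    if has_nvlink():
        print("NVLink detected: vLLM custom all-reduce stays enabled by default.")
    else:
        print("No NVLink reported: falling back to the standard NCCL all_reduce path.")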

requirements.txt

Lines changed: 2 additions & 1 deletion

@@ -84,4 +84,5 @@ ujson==5.10.0
 frozendict==2.4.6
 atomics==1.0.3
 easydict==1.13
-gunicorn==23.0.0
+gunicorn==23.0.0
+vllm==0.7.2
