[Docs] add quantization docs #3253

Open · wants to merge 15 commits into main
267 changes: 267 additions & 0 deletions docs/backend/quantization.ipynb
@@ -0,0 +1,267 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quantization\n",
"\n",
"SGLang support various quantization methods, including offline quantization for weight and online dynamic quantization for activation only (we do not recommend online quantization for weights).\n",
"\n",
"For model quantization, you have three options:\n",
"\n",
"1. Use official quantized versions if available (recommended, e.g. official Llama quantized models)\n",
"2. Use third-party quantized versions (e.g. models from neuralmagic collection here)\n",
"3. Quantize the models yourself\n",
"\n",
"## Online Quantization\n",
"\n",
"> Note: Although we support online quantization, users are advised to load offline quantized weights \n",
"\n",
"To enable online quantization, you can simply specify `--quantization` in the command line. For example, if you want to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`, you can launch the server with the following command:\n",
"\n",
"```bash\n",
"python3 -m sglang.launch_server \\\n",
" --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --quantization fp8 \\\n",
" --port 30000 --host 0.0.0.0\n",
"```\n",
"\n",
"which is equivalent to the following code block:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [
{
"data": {
"text/html": [
"<div style='margin: 20px; padding: 15px; background: #f8f9fa; border-radius: 5px; border: 1px solid #dee2e6;'>\n",
"<strong style='color: #00008B;'>\n",
" NOTE: Typically, the server runs in a separate terminal.<br>\n",
" In this notebook, we run the server and notebook code together, so their outputs are combined.<br>\n",
" To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.<br>\n",
"</strong>\n",
"</div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2024-12-27 05:23:35] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization='fp8', context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=559844691, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)\n",
"[2024-12-27 05:23:44 TP0] Init torch distributed begin.\n",
"[2024-12-27 05:23:45 TP0] Load weight begin. avail mem=78.58 GB\n",
"[2024-12-27 05:23:46 TP0] Using model weights format ['*.safetensors']\n",
"Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]\n",
"Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:01, 1.58it/s]\n",
"Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.30it/s]\n",
"Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.23it/s]\n",
"Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.61it/s]\n",
"Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.49it/s]\n",
"\n",
"[2024-12-27 05:23:49 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=69.68 GB\n",
"[2024-12-27 05:23:49 TP0] Memory pool end. avail mem=8.31 GB\n",
"[2024-12-27 05:23:50 TP0] Capture cuda graph begin. This can take up to several minutes.\n",
"100%|██████████| 23/23 [00:08<00:00, 2.60it/s]\n",
"[2024-12-27 05:23:59 TP0] Capture cuda graph end. Time elapsed: 8.85 s\n",
"[2024-12-27 05:23:59 TP0] max_total_num_tokens=493601, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072\n",
"[2024-12-27 05:23:59] INFO: Started server process [1369]\n",
"[2024-12-27 05:23:59] INFO: Waiting for application startup.\n",
"[2024-12-27 05:23:59] INFO: Application startup complete.\n",
"[2024-12-27 05:23:59] ERROR: [Errno 98] error while attempting to bind on address ('0.0.0.0', 30000): address already in use\n",
"[2024-12-27 05:23:59] INFO: Waiting for application shutdown.\n",
"[2024-12-27 05:23:59] INFO: Application shutdown complete.\n",
"[2024-12-27 05:24:00] The server is fired up and ready to roll!\n"
]
}
],
"source": [
"from sglang.utils import (\n",
" execute_shell_command,\n",
" wait_for_server,\n",
" terminate_process,\n",
" print_highlight,\n",
")\n",
"\n",
"server_process = execute_shell_command(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --quantization fp8 --port 30000 --host 0.0.0.0\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30000\")"
]
},
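{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the server is up, you can send it a quick test request. The snippet below is a minimal sketch (assuming the `requests` package is available) that uses the OpenAI-compatible `/v1/chat/completions` endpoint exposed by SGLang; the prompt and sampling parameters are arbitrary examples:\n",
"\n",
"```python\n",
"import requests\n",
"\n",
"# Send a small chat completion request to the FP8-quantized server.\n",
"response = requests.post(\n",
"    \"http://localhost:30000/v1/chat/completions\",\n",
"    json={\n",
"        \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
"        \"messages\": [{\"role\": \"user\", \"content\": \"What is FP8 quantization?\"}],\n",
"        \"max_tokens\": 64,\n",
"    },\n",
")\n",
"print(response.json()[\"choices\"][0][\"message\"][\"content\"])\n",
"```"
]
},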
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our team is working on supporting more quantization methods. We will soon support other quantization methods including but not limited to `[\"gptq\", \"marlin\", \"gptq_marlin\", \"awq_marlin\", \"bitsandbytes\", \"gguf\"]`\n",
"\n",
"**Note that:** Some of these quantization methods are still under development and may not be fully stable yet.\n",
"\n",
"We also support quantization methods based on [torchao](https://github.com/pytorch/ao). You can simply specify `--torchao-config` in the command line to support this feature. For example, if you want to enable `int4wo-128` for model `meta-llama/Meta-Llama-3.1-8B-Instruct` \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2024-12-27 05:24:59] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization=None, context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=122963417, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='int4wo-128', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)\n",
"[2024-12-27 05:25:08 TP0] Init torch distributed begin.\n",
"[2024-12-27 05:25:08 TP0] Load weight begin. avail mem=78.58 GB\n",
"[2024-12-27 05:25:09 TP0] Using model weights format ['*.safetensors']\n",
"Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00<?, ?it/s]\n",
"Loading safetensors checkpoint shards: 25% Completed | 1/4 [00:00<00:02, 1.39it/s]\n",
"Loading safetensors checkpoint shards: 50% Completed | 2/4 [00:01<00:01, 1.23it/s]\n",
"Loading safetensors checkpoint shards: 75% Completed | 3/4 [00:02<00:00, 1.19it/s]\n",
"Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.54it/s]\n",
"Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00, 1.42it/s]\n",
"\n",
"[2024-12-27 05:25:12 TP0] Load weight end. type=LlamaForCausalLM, dtype=torch.bfloat16, avail mem=63.49 GB\n",
"[2024-12-27 05:25:13 TP0] Memory pool end. avail mem=8.32 GB\n",
"[2024-12-27 05:25:13 TP0] Capture cuda graph begin. This can take up to several minutes.\n",
"100%|██████████| 23/23 [00:11<00:00, 2.09it/s]\n",
"[2024-12-27 05:25:24 TP0] Capture cuda graph end. Time elapsed: 11.03 s\n",
"[2024-12-27 05:25:24 TP0] max_total_num_tokens=519297, max_prefill_tokens=16384, max_running_requests=2049, context_len=131072\n",
"[2024-12-27 05:25:25] INFO: Started server process [2125]\n",
"[2024-12-27 05:25:25] INFO: Waiting for application startup.\n",
"[2024-12-27 05:25:25] INFO: Application startup complete.\n",
"[2024-12-27 05:25:25] INFO: Uvicorn running on http://0.0.0.0:30000 (Press CTRL+C to quit)\n",
"[2024-12-27 05:25:25] INFO: 127.0.0.1:40696 - \"GET /v1/models HTTP/1.1\" 200 OK\n",
"[2024-12-27 05:25:26] INFO: 127.0.0.1:40702 - \"GET /get_model_info HTTP/1.1\" 200 OK\n",
"[2024-12-27 05:25:26 TP0] Prefill batch. #new-seq: 1, #new-token: 7, #cached-token: 0, cache hit rate: 0.00%, token usage: 0.00, #running-req: 0, #queue-req: 0\n",
"[2024-12-27 05:25:27] INFO: 127.0.0.1:40704 - \"POST /generate HTTP/1.1\" 200 OK\n",
"[2024-12-27 05:25:27] The server is fired up and ready to roll!\n"
]
}
],
"source": [
"\n",
"from sglang.utils import (\n",
" execute_shell_command,\n",
" wait_for_server,\n",
" terminate_process,\n",
" print_highlight,\n",
")\n",
"\n",
"server_process = execute_shell_command(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --torchao-config int4wo-128 --port 30000 --host 0.0.0.0\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30000\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We support the following quantization methods based on torchao `[\"int8dq\", \"int8wo\", \"fp8wo\", \"fp8dq-per_tensor\", \"fp8dq-per_row\", \"int4wo-32\", \"int4wo-64\", \"int4wo-128\", \"int4wo-256\"]`\n",
"\n",
"Note: According to [this issue](https://github.com/sgl-project/sglang/issues/2219#issuecomment-2561890230), `\"int8dq\"` method currently has some bugs when using together with cuda graph capture. So we suggest to use `--disable-cuda-graph` capture when using `\"int8dq\"` method. \n",
"\n"
]
},
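{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a launch command for `\"int8dq\"` with CUDA graph capture disabled could look like the following sketch (not executed in this notebook; it reuses the same helpers as the cells above):\n",
"\n",
"```python\n",
"from sglang.utils import execute_shell_command, wait_for_server\n",
"\n",
"# int8dq currently has issues with CUDA graph capture, so disable it explicitly.\n",
"server_process = execute_shell_command(\n",
"    \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --torchao-config int8dq --disable-cuda-graph --port 30000 --host 0.0.0.0\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30000\")\n",
"```"
]
},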
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Offline Quantization\n",
"\n",
"To do offline quantization for your model, firstly you need to install [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:\n",
"\n",
"```bash\n",
"pip install llmcompressor\n",
"```\n",
"\n",
"Here, we take quantize `meta-llama/Meta-Llama-3-8B-Instruct` to `FP8` as an example to elaborate on how to do offline quantization.\n",
"\n",
"```python\n",
"from transformers import AutoTokenizer\n",
"from llmcompressor.transformers import SparseAutoModelForCausalLM\n",
"from llmcompressor.transformers import oneshot\n",
"from llmcompressor.modifiers.quantization import QuantizationModifier\n",
"\n",
"# Step 1: Load the original model.\n",
"MODEL_ID = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
"\n",
"model = SparseAutoModelForCausalLM.from_pretrained(\n",
" MODEL_ID, device_map=\"auto\", torch_dtype=\"auto\")\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
"\n",
"# Step 2: Perform offline quantization.\n",
"# Step 2.1: Configure the simple PTQ quantization.\n",
"recipe = QuantizationModifier(\n",
" targets=\"Linear\", scheme=\"FP8_DYNAMIC\", ignore=[\"lm_head\"])\n",
"\n",
"# Step 2.2: Apply the quantization algorithm.\n",
"oneshot(model=model, recipe=recipe)\n",
"\n",
"# Step 3: Save the model.\n",
"SAVE_DIR = MODEL_ID.split(\"/\")[1] + \"-FP8-Dynamic\"\n",
"model.save_pretrained(SAVE_DIR)\n",
"tokenizer.save_pretrained(SAVE_DIR)\n",
"```\n",
"\n",
"Then, you can directly use the quantized model with `SGLang`, by using the following command:\n",
"\n",
"```bash\n",
"python3 -m sglang.launch_server \\\n",
" --model-path $PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic \\\n",
" --port 30000 --host 0.0.0.0\n",
"```\n",
"\n",
"Note: If the model has already quantized offline, please **do not** add `--quantization` argument when starting the engine.\n",
"\n",
"\n",
"## Reference\n",
"\n",
"- [quantization document of vllm](https://docs.vllm.ai/en/latest/quantization/fp8.html)\n",
"\n",
"- [torchao](https://github.com/pytorch/ao)\n",
"\n",
"- [llm-compressor](https://github.com/vllm-project/llm-compressor/)\n"
]
}
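,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, remember to shut down the server launched earlier in this notebook so that port 30000 is freed (a small cleanup step using the `terminate_process` helper imported above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shut down the torchao server launched earlier in this notebook.\n",
"terminate_process(server_process)"
]
}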
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}