diff --git a/docs/backend/quantization.ipynb b/docs/backend/quantization.ipynb
new file mode 100644
index 00000000000..da7896e5fed
--- /dev/null
+++ b/docs/backend/quantization.ipynb
@@ -0,0 +1,267 @@

# Quantization

SGLang supports various quantization methods, including offline quantization of model weights and online dynamic quantization of activations (we do not recommend quantizing weights online).

For model quantization, you have three options:

1. Use official quantized versions if available (recommended, e.g. the official Llama quantized models)
2. Use third-party quantized versions (e.g. models from the neuralmagic collection on Hugging Face)
3. Quantize the models yourself

## Online Quantization

> Note: Although we support online quantization, users are advised to load offline-quantized weights instead.

To enable online quantization, simply specify `--quantization` on the command line. For example, to enable `FP8` quantization for the model `meta-llama/Meta-Llama-3.1-8B-Instruct`, launch the server with the following command:

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --quantization fp8 \
    --port 30000 --host 0.0.0.0
```

which is equivalent to the following code block:
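A minimal sketch of such a launch cell, assuming the `execute_shell_command` and `wait_for_server` helpers from `sglang.utils` (used in other SGLang documentation notebooks) are available:

```python
from sglang.utils import execute_shell_command, wait_for_server

# Launch the server with online FP8 quantization enabled
# (same flags as the bash command above).
server_process = execute_shell_command(
    "python3 -m sglang.launch_server "
    "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
    "--quantization fp8 --port 30000 --host 0.0.0.0"
)

# Block until the server is ready to accept requests.
wait_for_server("http://localhost:30000")
```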
\n", + "\n", + " NOTE: Typically, the server runs in a separate terminal.
\n", + " In this notebook, we run the server and notebook code together, so their outputs are combined.
\n", + " To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-12-27 05:23:35] server_args=ServerArgs(model_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_path='meta-llama/Meta-Llama-3.1-8B-Instruct', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', trust_remote_code=False, dtype='auto', kv_cache_dtype='auto', quantization='fp8', context_length=None, device='cuda', served_model_name='meta-llama/Meta-Llama-3.1-8B-Instruct', chat_template=None, is_embedding=False, revision=None, host='0.0.0.0', port=30000, mem_fraction_static=0.88, max_running_requests=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, cpu_offload_gb=0, tp_size=1, stream_interval=1, random_seed=559844691, constrained_json_whitespace_pattern=None, watchdog_timeout=300, download_dir=None, base_gpu_id=0, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, enable_metrics=False, decode_log_interval=40, api_key=None, file_storage_pth='SGLang_storage', enable_cache_report=False, dp_size=1, load_balance_method='round_robin', ep_size=1, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, lora_paths=None, max_loras_per_batch=8, attention_backend='flashinfer', sampling_backend='flashinfer', grammar_backend='outlines', disable_radix_cache=False, disable_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, disable_mla=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_ep_moe=False, enable_torch_compile=False, torch_compile_max_bs=32, cuda_graph_max_bs=160, torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, num_continuous_decode_steps=1, delete_ckpt_after_loading=False)\n", + "[2024-12-27 05:23:44 TP0] Init torch distributed begin.\n", + "[2024-12-27 05:23:45 TP0] Load weight begin. avail mem=78.58 GB\n", + "[2024-12-27 05:23:46 TP0] Using model weights format ['*.safetensors']\n", + "Loading safetensors checkpoint shards: 0% Completed | 0/4 [00:00 Note: Although we support online quantization, we recommend users to use quantized models. + +To enable online quantization, you can simply specify `--quantization` in the command line. For example, if you want to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`, you can launch the server with the following command: + +```bash +python3 -m sglang.launch_server \ + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --quantization fp8 \ + --port 30000 --host 0.0.0.0 +``` + +Our team is working on supporting more quantization methods. We will soon support other quantization methods including but not limited to `["awq", "gptq", "marlin", "gptq_marlin", "awq_marlin", "bitsandbytes", "gguf"]` + +We also support quantization methods based on [torchao](https://github.com/pytorch/ao). You can simply specify `--torchao-config` in the command line to support this feature. 
For example, to enable `int4wo-128` for the model `meta-llama/Meta-Llama-3.1-8B-Instruct`, launch the server with the following command:

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int4wo-128 \
    --port 30000 --host 0.0.0.0
```

We support the following torchao-based quantization methods: `["int8dq", "int8wo", "fp8wo", "fp8dq-per_tensor", "fp8dq-per_row", "int4wo-32", "int4wo-64", "int4wo-128", "int4wo-256"]`.

Note: According to [this issue](https://github.com/sgl-project/sglang/issues/2219#issuecomment-2561890230), the `"int8dq"` method currently has some bugs when used together with CUDA graph capture, so we suggest disabling CUDA graph capture when using it:

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
    --torchao-config int8dq \
    --disable-cuda-graph \
    --port 30000 --host 0.0.0.0
```


## Offline Quantization

To quantize a model offline, first install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:

```bash
pip install llmcompressor
```

Here, we take quantizing `meta-llama/Meta-Llama-3-8B-Instruct` to `FP8` as an example of how to perform offline quantization.

```python
from transformers import AutoTokenizer
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Step 1: Load the original model.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Step 2: Perform offline quantization.
# Step 2.1: Configure the simple PTQ quantization recipe.
recipe = QuantizationModifier(
    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Step 2.2: Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)

# Step 3: Save the quantized model and tokenizer.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
```

Then, you can serve the quantized model directly with SGLang using the following command:

```bash
python3 -m sglang.launch_server \
    --model-path $PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic \
    --port 30000 --host 0.0.0.0
```

**Note:** If the model has already been quantized offline, do **not** add the `--quantization` argument when starting the engine.


## Reference

- [vLLM quantization documentation](https://docs.vllm.ai/en/latest/quantization/fp8.html)

- [torchao](https://github.com/pytorch/ao)

- [llm-compressor](https://github.com/vllm-project/llm-compressor/)