diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 23037650a31..e8b25da0704 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,11 +1,12 @@
 import logging
-from typing import List
+from typing import Dict, List
 
 import torch
 from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
 
@@ -35,6 +36,10 @@ def forward(
     ):
         logits = logits_output.next_token_logits
 
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            self._apply_custom_logit_processor(logits, sampling_info)
+
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
             logits = torch.where(
@@ -121,6 +126,29 @@ def forward(
 
         return batch_next_token_ids
 
+    def _apply_custom_logit_processor(
+        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+    ):
+        """Apply custom logit processors to the logits.
+        This function will modify the logits in-place."""
+
+        for _, (
+            processor,
+            batch_mask,
+        ) in sampling_batch_info.custom_logit_processor.items():
+            # Get the batch indices that need to be processed
+            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+            # Apply the processor to the logits
+            logits[batch_mask] = processor(
+                logits[batch_mask],
+                [sampling_batch_info.custom_params[i] for i in batch_indices],
+            )
+
+            logger.debug(
+                f"Custom logit processor {processor.__class__.__name__} is applied."
+            )
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
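For intuition, here is a minimal, self-contained sketch of the masked-application pattern that `_apply_custom_logit_processor` uses (the `force_token` callable and the tensor shapes below are illustrative, not part of the diff): each processor only sees the rows of `logits` whose boolean mask is set, together with the matching entries of `custom_params`, and the processed rows are written back through the same mask.

```python
import torch

# Illustrative stand-in for a deserialized custom logit processor.
def force_token(logits: torch.Tensor, param_list: list) -> torch.Tensor:
    for row, params in enumerate(param_list):
        logits[row, :] = float("-inf")          # mask every token ...
        logits[row, params["token_id"]] = 0.0   # ... except the requested one
    return logits

logits = torch.zeros(4, 8)                       # batch of 4, vocab of 8
batch_mask = torch.tensor([True, False, True, False])
custom_params = [{"token_id": 3}, None, {"token_id": 5}, None]

batch_indices = batch_mask.nonzero(as_tuple=True)[0]
# Boolean indexing returns a copy, so the processed rows must be written back.
logits[batch_mask] = force_token(
    logits[batch_mask], [custom_params[i] for i in batch_indices]
)
assert logits[0].argmax().item() == 3 and logits[2].argmax().item() == 5
```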
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c5a35ced00c..5a803dd997a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,6 +22,7 @@
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import SamplingParams
 
 
@@ -69,6 +70,8 @@ class GenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -183,6 +186,13 @@ def normalize_batch_and_arguments(self):
             else:
                 assert self.parallel_sample_num == 1
 
+            if self.custom_logit_processor is None:
+                self.custom_logit_processor = [None] * num
+            elif not isinstance(self.custom_logit_processor, list):
+                self.custom_logit_processor = [self.custom_logit_processor] * num
+            else:
+                assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
@@ -202,6 +212,11 @@ def __getitem__(self, i):
             log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )
 
 
@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index afbc98b7ca9..a09810a3871 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -232,6 +232,7 @@ def __init__(
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -252,6 +253,7 @@ def __init__(
         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
         self.req_pool_idx = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5df9c24cee1..a89bd1bc4f4 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -614,6 +614,19 @@ def handle_generate_request(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids
 
+        # Handle custom logit processor passed to the request
+        custom_logit_processor = recv_req.custom_logit_processor
+        if (
+            not self.server_args.enable_custom_logit_processor
+            and custom_logit_processor is not None
+        ):
+            logger.warning(
+                "The SGLang server is not configured to enable custom logit processor. "
+                "The custom logit processor passed in will be ignored. "
+                "Please set --enable-custom-logit-processor to enable this feature."
+            )
+            custom_logit_processor = None
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
@@ -624,6 +637,7 @@ def handle_generate_request(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            custom_logit_processor=custom_logit_processor,
             eos_token_ids=self.model_config.hf_eos_token_id,
         )
         req.tokenizer = self.tokenizer
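The normalization rule added to `GenerateReqInput` mirrors the existing `lora_path` handling: a single serialized processor string is broadcast to every request in a batch, an explicit list is only legal when `parallel_sample_num == 1`, and `None` expands to a list of `None`s. A plain-Python restatement of that rule (no SGLang imports; `num` is the batch size):

```python
def normalize_custom_logit_processor(value, num, parallel_sample_num=1):
    if value is None:
        return [None] * num          # no processor for any request
    if not isinstance(value, list):
        return [value] * num         # one serialized string, broadcast to all
    # An explicit per-request list is incompatible with parallel sampling.
    assert parallel_sample_num == 1
    return value

assert normalize_custom_logit_processor(None, 3) == [None, None, None]
assert normalize_custom_logit_processor("ser", 3) == ["ser", "ser", "ser"]
```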
diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py
index e9c0c909d52..4f4af636757 100644
--- a/python/sglang/srt/managers/session_controller.py
+++ b/python/sglang/srt/managers/session_controller.py
@@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9cf6d9cc556..3e349300553 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -381,6 +381,7 @@ async def _tokenize_one_request(
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py
new file mode 100644
index 00000000000..a64b2498f23
--- /dev/null
+++ b/python/sglang/srt/sampling/custom_logit_processor.py
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+    """Deserialize a json string to a Callable object.
+    This function is cached to avoid redundant deserialization.
+    """
+    data = json.loads(json_str)
+    return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+    """Abstract base class for callable functions."""
+
+    @abstractmethod
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        """Define the callable behavior."""
+        raise NotImplementedError
+
+    def to_str(self) -> str:
+        """Serialize the callable function to a JSON-compatible string."""
+        return json.dumps({"callable": dill.dumps(self).hex()})
+
+    @classmethod
+    def from_str(cls, json_str: str):
+        """Deserialize a callable function from a JSON string."""
+        return _cache_from_str(json_str)
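The new module defines the wire format: `to_str` produces a JSON object whose `"callable"` field is a hex-encoded dill pickle, and `from_str` routes through the module-level `_cache_from_str` so repeated requests carrying the same string are deserialized only once per process. A round-trip sketch using the API above (the `ForceTokenProcessor` subclass is hypothetical):

```python
import torch
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class ForceTokenProcessor(CustomLogitProcessor):
    """Hypothetical processor: force each request to sample params["token_id"]."""

    def __call__(self, logits, custom_param_list=None):
        for i, params in enumerate(custom_param_list):
            logits[i, :] = float("-inf")
            logits[i, params["token_id"]] = 0.0
        return logits

serialized = ForceTokenProcessor().to_str()           # JSON string, dill payload in hex
restored = CustomLogitProcessor.from_str(serialized)  # deserialize (and cache)

out = restored(torch.zeros(1, 16), [{"token_id": 7}])
assert out.argmax(dim=-1).item() == 7
# Thanks to lru_cache, the same string maps to the same object on later calls.
assert CustomLogitProcessor.from_str(serialized) is restored
```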
+ """ + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) + + +class CustomLogitProcessor(ABC): + """Abstract base class for callable functions.""" + + @abstractmethod + def __call__( + self, + logits: torch.Tensor, + custom_param_list: Optional[List[Dict[str, Any]]] = None, + ) -> torch.Tensor: + """Define the callable behavior.""" + raise NotImplementedError + + def to_str(self) -> str: + """Serialize the callable function to a JSON-compatible string.""" + return json.dumps({"callable": dill.dumps(self).hex()}) + + @classmethod + def from_str(cls, json_str: str): + """Deserialize a callable function from a JSON string.""" + return _cache_from_str(json_str) diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6eda63c706a..d4c5c32386a 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,7 +3,7 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch @@ -14,6 +14,7 @@ from sgl_kernel import sampling_scaling_penalties import sglang.srt.sampling.penaltylib as penaltylib +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor logger = logging.getLogger(__name__) @@ -36,6 +37,9 @@ class SamplingBatchInfo: # Dispatch in CUDA graph need_min_p_sampling: bool + # Whether any request has custom logit processor + has_custom_logit_processor: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -52,6 +56,14 @@ class SamplingBatchInfo: # Device device: str = "cuda" + # Custom Parameters + custom_params: Optional[List[Optional[Dict[str, Any]]]] = None + + # Custom Logit Processor + custom_logit_processor: Optional[ + Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] + ] = None + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool @@ -76,6 +88,36 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) + # Check if any request has custom logit processor + has_custom_logit_processor = any(r.custom_logit_processor for r in reqs) + + if has_custom_logit_processor: + # Merge the same type of custom logit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processor = { + hash(processor_str): ( + # The deserialized custom logit processor object + CustomLogitProcessor.from_str(processor_str), + # The mask tensor for the requests that use this custom logit processor + torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True), + ) + for processor_str, true_indices in processor_dict.items() + } + custom_params = [r.sampling_params.custom_params for r in reqs] + else: + merged_custom_logit_processor = None + custom_params = None + ret = cls( temperatures=temperatures, top_ps=top_ps, @@ -83,8 +125,11 @@ def from_schedule_batch( min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), + 
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index d1d932693c6..2224fb0919a 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 _SAMPLING_EPS = 1e-6
 
@@ -48,6 +48,7 @@ def __init__(
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -71,6 +72,7 @@ def __init__(
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index a2c1cb375dc..2cb2cd95dc8 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -773,6 +773,7 @@ def generate(
     logprob_start_len: Optional[Union[List[int], int]] = None,
     top_logprobs_num: Optional[Union[List[int], int]] = None,
     lora_path: Optional[List[Optional[str]]] = None,
+    custom_logit_processor: Optional[Union[List[str], str]] = None,
     stream: bool = False,
 ):
     obj = GenerateReqInput(
@@ -784,6 +785,7 @@ def generate(
         top_logprobs_num=top_logprobs_num,
         lora_path=lora_path,
         stream=stream,
+        custom_logit_processor=custom_logit_processor,
     )
 
     # get the current event loop
@@ -824,6 +826,7 @@ async def async_generate(
     logprob_start_len: Optional[Union[List[int], int]] = None,
     top_logprobs_num: Optional[Union[List[int], int]] = None,
     lora_path: Optional[List[Optional[str]]] = None,
+    custom_logit_processor: Optional[Union[str, List[str]]] = None,
     stream: bool = False,
 ):
     obj = GenerateReqInput(
@@ -835,6 +838,7 @@ async def async_generate(
         top_logprobs_num=top_logprobs_num,
         lora_path=lora_path,
         stream=stream,
+        custom_logit_processor=custom_logit_processor,
     )
 
     ret = await generate_request(obj, None)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 052e316b7c4..6dd0b945654 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -159,6 +159,9 @@ class ServerArgs:
     enable_memory_saver: bool = False
     allow_auto_truncate: bool = False
 
+    # Custom logit processor
+    enable_custom_logit_processor: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -865,6 +868,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
         )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
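Putting the server pieces together: the feature is off by default because the server `dill.loads` client-supplied bytes, which executes arbitrary code. A hedged launch-and-call sketch; the `Runtime` constructor arguments and the `generate` keywords below are assumptions based on the signatures extended in this diff, not a documented API:

```python
# Server side: opt in explicitly, e.g.
#   python -m sglang.launch_server --model-path <model> --enable-custom-logit-processor
# or programmatically via the Runtime wrapper around server.py:
from sglang.srt.server import Runtime

runtime = Runtime(
    model_path="<your-model>",
    enable_custom_logit_processor=True,  # maps to the new ServerArgs field
)
output = runtime.generate(
    prompt="Question: Is Paris the capital of France? Answer:",
    custom_logit_processor=ForceTokenProcessor().to_str(),  # from the earlier sketch
)
```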
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 0fd71efcb0b..7afdc9bf41c 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -5,10 +5,12 @@
 
 import json
 import unittest
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 import requests
 
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -24,7 +26,10 @@ def setUpClass(cls):
         cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
-            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=("--enable-custom-logit-processor",),
         )
 
     @classmethod
@@ -248,6 +253,62 @@ def test_logprob_grammar(self):
 
         self.assertTrue(all(x is not None for x in logprobs))
 
+    def run_custom_logit_processor(self, target_token_id: int):
+        """Test custom logit processor with custom params."""
+
+        custom_params = {"token_id": target_token_id}
+
+        class DeterministicLogitProcessor(CustomLogitProcessor):
+            """A dummy logit processor that changes the logits to always
+            sample the given token id.
+            """
+
+            def __call__(self, logits, custom_param_list):
+                assert logits.shape[0] == len(custom_param_list)
+                key = "token_id"
+
+                for i, param_dict in enumerate(custom_param_list):
+                    # Mask all other tokens
+                    logits[i, :] = -float("inf")
+                    # Assign highest probability to the specified token
+                    logits[i, param_dict[key]] = 0.0
+                return logits
+
+        prompts = "Question: Is Paris the Capital of France? Answer:"
+
+        # Base case json data to be posted to the server.
+        base_json = {
+            "text": prompts,
+            "sampling_params": {"temperature": 0.0},
+            "return_logprob": True,
+        }
+
+        # Custom json data with custom logit processor and params.
+        custom_json = base_json.copy()
+        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
+        custom_json["sampling_params"]["custom_params"] = custom_params
+
+        custom_response = requests.post(
+            self.base_url + "/generate",
+            json=custom_json,
+        ).json()
+
+        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
+        sampled_tokens = [x[1] for x in output_token_logprobs]
+
+        # The logit processor should always sample the given token, as the logits are deterministic.
+        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+
+    def test_custom_logit_processor(self):
+        """Test custom logit processor with a single request."""
+        self.run_custom_logit_processor(target_token_id=5)
+
+    def test_custom_logit_processor_batch(self):
+        """Test custom logit processor with a batch of requests."""
+        target_token_ids = list(range(32))
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
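Outside the test harness, the same flow over plain HTTP looks like the sketch below. It assumes a server launched with `--enable-custom-logit-processor` at a local address (the URL is an assumption); the processor class is a copy of the test's.

```python
import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class DeterministicLogitProcessor(CustomLogitProcessor):
    """Always sample custom_params["token_id"], as in the test above."""

    def __call__(self, logits, custom_param_list):
        for i, params in enumerate(custom_param_list):
            logits[i, :] = -float("inf")
            logits[i, params["token_id"]] = 0.0
        return logits

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed local server address
    json={
        "text": "Question: Is Paris the capital of France? Answer:",
        "sampling_params": {"temperature": 0.0, "custom_params": {"token_id": 5}},
        "return_logprob": True,
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
    },
).json()

# Every sampled token should be token 5, mirroring test_custom_logit_processor.
sampled_tokens = [x[1] for x in resp["meta_info"]["output_token_logprobs"]]
assert all(tok == 5 for tok in sampled_tokens)
```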