[Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo <[email protected]>
hongpeng-guo authored Jan 19, 2025
1 parent 3bcf5ec commit e403d23
Showing 12 changed files with 302 additions and 4 deletions.
30 changes: 29 additions & 1 deletion python/sglang/srt/layers/sampler.py
@@ -1,11 +1,12 @@
import logging
from typing import List
from typing import Dict, List

import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

@@ -35,6 +36,10 @@ def forward(
):
logits = logits_output.next_token_logits

# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
self._apply_custom_logit_processor(logits, sampling_info)

if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
@@ -121,6 +126,29 @@ def forward(

return batch_next_token_ids

def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""

for _, (
processor,
batch_mask,
) in sampling_batch_info.custom_logit_processor.items():
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]

# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
[sampling_batch_info.custom_params[i] for i in batch_indices],
)

logger.debug(
f"Custom logit processor {processor.__class__.__name__} is applied."
)


def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
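For reference, the per-request dispatch added above can be read as the following minimal sketch: each registered processor carries a boolean batch mask selecting the rows it applies to, and only the matching entries of custom_params are handed to it. The helper names apply_processors and ban_token_zero are illustrative only and do not appear in this commit.

import torch

def apply_processors(logits, processors, custom_params):
    # processors maps an arbitrary key to (callable, bool mask of shape [batch_size]),
    # mirroring sampling_batch_info.custom_logit_processor above.
    for processor, batch_mask in processors.values():
        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
        # Only the masked rows and their matching per-request params are passed in.
        logits[batch_mask] = processor(
            logits[batch_mask],
            [custom_params[i] for i in batch_indices],
        )
    return logits

def ban_token_zero(logits, params):
    # Toy processor: forbid token id 0 for every request it covers.
    logits[:, 0] = float("-inf")
    return logits

logits = torch.zeros(4, 8)
mask = torch.tensor([True, False, True, False])
apply_processors(logits, {"ban0": (ban_token_zero, mask)}, [{}] * 4)
print(logits[0, 0].item(), logits[1, 0].item())  # -inf 0.0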
19 changes: 19 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -22,6 +22,7 @@
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import SamplingParams


@@ -69,6 +70,8 @@ class GenerateReqInput:

# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
# Custom logit processor (serialized function)
custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None

def normalize_batch_and_arguments(self):
if (
@@ -183,6 +186,13 @@ def normalize_batch_and_arguments(self):
else:
assert self.parallel_sample_num == 1

if self.custom_logit_processor is None:
self.custom_logit_processor = [None] * num
elif not isinstance(self.custom_logit_processor, list):
self.custom_logit_processor = [self.custom_logit_processor] * num
else:
assert self.parallel_sample_num == 1

def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid
@@ -202,6 +212,11 @@ def __getitem__(self, i):
log_metrics=self.log_metrics,
modalities=self.modalities[i] if self.modalities else None,
lora_path=self.lora_path[i] if self.lora_path is not None else None,
custom_logit_processor=(
self.custom_logit_processor[i]
if self.custom_logit_processor is not None
else None
),
)


@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
# Session info for continual prompting
session_params: Optional[SessionParams] = None

# Custom logit processor (serialized function)
# TODO (hpguo): Add an example and update doc string here
custom_logit_processor: Optional[str] = None


@dataclass
class EmbeddingReqInput:
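The normalization added to normalize_batch_and_arguments above broadcasts a single serialized processor across a batched request. The standalone sketch below reproduces that rule; the helper name normalize_clp is hypothetical and used only for illustration.

from typing import List, Optional, Union

def normalize_clp(
    custom_logit_processor: Optional[Union[List[Optional[str]], str]],
    num: int,
    parallel_sample_num: int = 1,
) -> List[Optional[str]]:
    # None -> one None per prompt; a single string -> replicated per prompt;
    # an explicit list is only accepted when parallel_sample_num == 1.
    if custom_logit_processor is None:
        return [None] * num
    if not isinstance(custom_logit_processor, list):
        return [custom_logit_processor] * num
    assert parallel_sample_num == 1
    return custom_logit_processor

print(normalize_clp(None, 3))            # [None, None, None]
print(normalize_clp("serialized", 3))    # ['serialized', 'serialized', 'serialized']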
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -232,6 +232,7 @@ def __init__(
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
custom_logit_processor: Optional[str] = None,
eos_token_ids: Optional[Set[int]] = None,
):
# Input and output info
@@ -252,6 +253,7 @@
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.custom_logit_processor = custom_logit_processor

# Memory pool info
self.req_pool_idx = None
14 changes: 14 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -614,6 +614,19 @@ def handle_generate_request(
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids

# Handle custom logit processor passed to the request
custom_logit_processor = recv_req.custom_logit_processor
if (
not self.server_args.enable_custom_logit_processor
and custom_logit_processor is not None
):
logger.warning(
"The SGLang server is not configured to enable custom logit processor."
"The custom logit processor passed in will be ignored."
"Please set --enable-custom-logits-processor to enable this feature."
)
custom_logit_processor = None

req = Req(
recv_req.rid,
recv_req.input_text,
@@ -624,6 +637,7 @@
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
custom_logit_processor=custom_logit_processor,
eos_token_ids=self.model_config.hf_eos_token_id,
)
req.tokenizer = self.tokenizer
1 change: 1 addition & 0 deletions python/sglang/srt/managers/session_controller.py
@@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
sampling_params=req.sampling_params,
lora_path=req.lora_path,
session_id=self.session_id,
custom_logit_processor=req.custom_logit_processor,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
1 change: 1 addition & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -381,6 +381,7 @@ async def _tokenize_one_request(
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
custom_logit_processor=obj.custom_logit_processor,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
38 changes: 38 additions & 0 deletions python/sglang/srt/sampling/custom_logit_processor.py
@@ -0,0 +1,38 @@
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, Dict, List, Optional

import dill
import torch


@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
"""Deserialize a json string to a Callable object.
This function is cached to avoid redundant deserialization.
"""
data = json.loads(json_str)
return dill.loads(bytes.fromhex(data["callable"]))


class CustomLogitProcessor(ABC):
"""Abstract base class for callable functions."""

@abstractmethod
def __call__(
self,
logits: torch.Tensor,
custom_param_list: Optional[List[Dict[str, Any]]] = None,
) -> torch.Tensor:
"""Define the callable behavior."""
raise NotImplementedError

def to_str(self) -> str:
"""Serialize the callable function to a JSON-compatible string."""
return json.dumps({"callable": dill.dumps(self).hex()})

@classmethod
def from_str(cls, json_str: str):
"""Deserialize a callable function from a JSON string."""
return _cache_from_str(json_str)
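
To illustrate the new class end to end, here is a hedged usage sketch: a subclass forces a fixed token id, is serialized with to_str(), and is attached to a generate request. The DeterministicLogitProcessor name and the placement of custom_params inside sampling_params are assumptions for illustration; they are not part of this diff.

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Toy processor: force the token id given in custom_params for each request."""

    def __call__(self, logits, custom_param_list=None):
        for i, params in enumerate(custom_param_list or []):
            token_id = params.get("token_id")
            if token_id is not None:
                # Set every logit except the chosen token to -inf.
                forced = torch.full_like(logits[i], float("-inf"))
                forced[token_id] = 0.0
                logits[i] = forced
        return logits


serialized = DeterministicLogitProcessor().to_str()

# The string round-trips through from_str() on the receiving side.
assert isinstance(CustomLogitProcessor.from_str(serialized), DeterministicLogitProcessor)

# Hypothetical request payload; field names follow GenerateReqInput, and the
# custom_params location inside sampling_params is an assumption.
payload = {
    "text": "The capital of France is",
    "custom_logit_processor": serialized,
    "sampling_params": {"max_new_tokens": 8, "custom_params": {"token_id": 5}},
}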
