Improve code style of sampler (#1168)
hnyls2002 authored Aug 21, 2024
1 parent ac1b74f commit 83e23c6
Showing 10 changed files with 269 additions and 195 deletions.
3 changes: 3 additions & 0 deletions examples/usage/json_decode.py
@@ -35,6 +35,9 @@ def character_gen(s, name):
name
+ " is a character in Harry Potter. Please fill in the following information about this character.\n"
)
s += "The constrained regex is:\n"
s += character_regex + "\n"
s += "The JSON output is:\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)


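For completeness, a hypothetical driver for the updated example is sketched below. It assumes an sglang server is already running at an assumed local URL; the simplified stand-in for character_regex and the character name are illustration values, and only the sgl.gen call with regex= mirrors the diff above.

import sglang as sgl

# Simplified stand-in for the example's character_regex (illustration only).
character_regex = (
    r"\{\n"
    r'    "name": "[\w ]{1,30}",\n'
    r'    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)"\n'
    r"\}"
)

@sgl.function
def character_gen(s, name):
    s += (
        name
        + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    )
    s += "The constrained regex is:\n"
    s += character_regex + "\n"
    s += "The JSON output is:\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)

# Assumed endpoint; point this at wherever the sglang server is listening.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = character_gen.run(name="Hermione Granger")
print(state["json_output"])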
2 changes: 1 addition & 1 deletion python/sglang/bench_latency.py
@@ -54,7 +54,7 @@
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers

101 changes: 101 additions & 0 deletions python/sglang/srt/layers/sampler.py
@@ -0,0 +1,101 @@
import logging

import torch
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
from vllm.model_executor.custom_op import CustomOp

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


class Sampler(CustomOp):
def __init__(self):
super().__init__()

def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)

if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

logits = sampling_info.penalizer_orchestrator.apply(logits)

probs = torch.softmax(logits, dim=-1)

if not global_server_args_dict["disable_flashinfer_sampling"]:
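            # Pre-draw uniform noise for the flashinfer sampling kernels (one row per sampling round, up to 32 rounds per request).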
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.min_ps.any():
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
)
else:
# Here we provide a slower fallback implementation.
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)

if not torch.all(success):
            logger.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
success, batch_next_token_ids, argmax_ids
)

return batch_next_token_ids

    def forward_native(self):
raise NotImplementedError("Native forward is not implemented yet.")


def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
top_ks: torch.Tensor,
top_ps: torch.Tensor,
min_ps: torch.Tensor,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
>= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
)
success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
return batch_next_token_ids, success

batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
return batch_next_token_ids, success
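To make the fallback's filtering steps easy to trace without installing flashinfer or sglang, the following self-contained sketch mirrors the same top-k / top-p / min-p masking on a toy batch. The vocabulary size, thresholds, and seed are made-up illustration values.

import torch

torch.manual_seed(0)

# Toy batch of 2 distributions over a vocabulary of 8 tokens.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
top_ks = torch.tensor([3, 5])        # keep at most k tokens per row
top_ps = torch.tensor([0.9, 0.8])    # nucleus mass per row
min_ps = torch.tensor([0.05, 0.10])  # fraction of each row's max prob

# Sort descending so the rank and cumulative-sum tests are easy to express.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)

# min-p threshold: min_p times the largest probability in the row.
min_p_thresholds = probs_sort[:, 0] * min_ps

# top-p: zero out tokens once the mass *before* them already exceeds top_p.
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
# top-k: zero out tokens ranked at position k or later.
probs_sort[
    torch.arange(probs.shape[-1]).view(1, -1) >= top_ks.view(-1, 1)
] = 0.0
# min-p: zero out tokens below the per-row threshold.
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

# Rescale by the row max, as the sglang code does; torch.multinomial only
# needs non-negative weights, so any positive rescaling samples the same
# distribution.
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
sampled_index = torch.multinomial(probs_sort, num_samples=1)
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
print(next_token_ids)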
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -23,7 +23,7 @@
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass
199 changes: 10 additions & 189 deletions python/sglang/srt/managers/schedule_batch.py
@@ -20,22 +20,14 @@
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
from vllm.distributed import get_tensor_model_parallel_group

import sglang.srt.sampling.penaltylib as penaltylib

from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -340,14 +332,6 @@ class ScheduleBatch:
return_logprob: bool = False
top_logprobs_nums: List[int] = None

# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
logit_bias: torch.Tensor = None

@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_logprob = any(req.return_logprob for req in reqs)
@@ -395,46 +379,6 @@ def alloc_token_slots(self, num_tokens: int):

return out_cache_loc

def batch_sampling_params(self, vocab_size):
device = "cuda"
bs, reqs = self.batch_size(), self.reqs
self.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
self.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)

# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
# should not add hefty computation overhead other than simple checks.
#
# While we choose not to even create the class instances if they are not required, this
# could add additional complexity to the {ScheduleBatch} class, especially we need to
# handle {filter_batch()} and {merge()} cases as well.
self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
vocab_size=vocab_size,
batch=self,
device=device,
Penalizers={
penaltylib.BatchedFrequencyPenalizer,
penaltylib.BatchedMinNewTokensPenalizer,
penaltylib.BatchedPresencePenalizer,
penaltylib.BatchedRepetitionPenalizer,
},
)

# Handle logit bias but only allocate when needed
self.logit_bias = None

def prepare_for_extend(self, vocab_size: int):
bs = self.batch_size()
reqs = self.reqs
@@ -475,7 +419,7 @@ def prepare_for_extend(self, vocab_size: int):
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

self.batch_sampling_params(vocab_size)
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

def mix_with_running(self, running_batch: "ScheduleBatch"):
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
@@ -684,6 +628,8 @@ def prepare_for_decode(self, input_ids=None):
self.req_pool_indices, self.seq_lens - 1
] = self.out_cache_loc

self.sampling_info.update_regex_vocab_mask(self)

def filter_batch(self, unfinished_indices: List[int]):
if unfinished_indices is None or len(unfinished_indices) == 0:
# Filter out all requests
@@ -704,24 +650,13 @@ def filter_batch(self, unfinished_indices: List[int]):
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)

self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
"logit_bias",
]:
self_val = getattr(self, item, None)
if self_val is not None: # logit_bias can be None
setattr(self, item, self_val[new_indices])
self.sampling_info.filter(unfinished_indices, new_indices)

def merge(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
self.sampling_info.merge(other.sampling_info)

self.reqs.extend(other.reqs)

@@ -736,125 +671,11 @@ def merge(self, other: "ScheduleBatch"):
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)

for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))

# logit_bias can be None
if self.logit_bias is not None or other.logit_bias is not None:
vocab_size = (
self.logit_bias.shape[1]
if self.logit_bias is not None
else other.logit_bias.shape[1]
)
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
if other.logit_bias is None:
other.logit_bias = torch.zeros(
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
logits = logits.contiguous()
logits.div_(self.temperatures)
if self.logit_bias is not None:
logits.add_(self.logit_bias)

has_regex = any(req.regex_fsm is not None for req in self.reqs)
if has_regex:
allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
allowed_mask.zero_()
allowed_mask[
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
] = 1
logits[i].masked_fill_(~allowed_mask, float("-inf"))

logits = self.penalizer_orchestrator.apply(logits)

probs = torch.softmax(logits, dim=-1)

if not global_server_args_dict["disable_flashinfer_sampling"]:
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if self.min_ps.any():
probs = top_k_renorm_prob(probs, self.top_ks)
probs = top_p_renorm_prob(probs, self.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
probs, uniform_samples, self.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs, uniform_samples, self.top_ks, self.top_ps
)
else:
# Here we provide a slower fallback implementation.
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, self.top_ks, self.top_ps, self.min_ps
)
from sglang.srt.layers.sampler import Sampler

if not torch.all(success):
logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
success, batch_next_token_ids, argmax_ids
)

if has_regex:
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
for i, req in enumerate(self.reqs):
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, batch_next_token_ids_cpu[i]
)
sampler = Sampler()

self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
batch_next_token_ids = sampler(logits, self.sampling_info)

return batch_next_token_ids


def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
top_ks: torch.Tensor,
top_ps: torch.Tensor,
min_ps: torch.Tensor,
):
"""A top-k, top-p and min-p sampling implementation with native pytorch operations."""
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
>= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
)
success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
return batch_next_token_ids, success

batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
return batch_next_token_ids, success
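The comment retained in merge() above notes that the penalizer orchestrator must be merged before Batch.reqs is extended, because the orchestrator reads the pre-merge request list while merging. The toy sketch below illustrates that ordering constraint with stand-in classes; it is not sglang code.

import torch

class ToyOrchestrator:
    """Stand-in for the penalizer orchestrator: keeps one scale per request."""

    def __init__(self, batch):
        self.batch = batch
        self.scales = torch.ones(len(batch.reqs))

    def merge(self, other):
        # Relies on batch.reqs still having its pre-merge length.
        assert len(self.scales) == len(self.batch.reqs)
        self.scales = torch.cat([self.scales, other.scales])

class ToyBatch:
    def __init__(self, reqs):
        self.reqs = reqs
        self.orchestrator = ToyOrchestrator(self)

    def merge(self, other):
        self.orchestrator.merge(other.orchestrator)  # must run first
        self.reqs.extend(other.reqs)                 # ...then grow reqs

a, b = ToyBatch(["r0", "r1"]), ToyBatch(["r2"])
a.merge(b)
assert len(a.orchestrator.scales) == len(a.reqs) == 3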
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tokenizer_manager.py
@@ -50,7 +50,7 @@
UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback