diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 11c94775c15..7f555110d9d 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -38,11 +38,6 @@ jobs: cd test/srt python3 -m unittest test_serving_throughput.TestServingThroughput.test_default - - name: Benchmark Serving Latency - timeout-minutes: 10 - run: | - python3 -m sglang.bench_latency --model meta-llama/Meta-Llama-3.1-8B-Instruct --batch-size 1 --input 128 --output 8 - - name: Benchmark Serving Throughput (w/o RadixAttention) timeout-minutes: 10 run: | diff --git a/README.md b/README.md index 305df444d09..223f9624f6e 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ### Method 2: From source ``` # Use the last release branch -git clone -b v0.2.14 https://github.com/sgl-project/sglang.git +git clone -b v0.2.14.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip diff --git a/python/pyproject.toml b/python/pyproject.toml index 4a46adc3fef..7b2741fd216 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.2.14" +version = "0.2.14.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 6a918fbd112..dea910f5772 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -200,14 +200,16 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) - return sample_output.batch_next_token_ids, logits_output.next_token_logits, batch + output = model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = batch.sample(output.next_token_logits) + return next_token_ids, output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids.cpu().numpy()) - sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) - return sample_output.batch_next_token_ids, logits_output.next_token_logits + output = model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = batch.sample(output.next_token_logits) + return next_token_ids, output.next_token_logits @torch.inference_mode() diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index b81f3d2a040..63f74d8b026 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -29,7 +29,7 @@ @dataclasses.dataclass -class LogitsProcessorOutput: +class LogitProcessorOutput: # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] @@ -185,7 +185,7 @@ def forward( # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: - return LogitsProcessorOutput( + return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, @@ -209,7 +209,7 @@ def forward( else: output_top_logprobs = None - return LogitsProcessorOutput( + return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, @@ -278,7 +278,7 @@ def forward( # Remove the last token logprob for the prefill tokens. input_token_logprobs = input_token_logprobs[:-1] - return LogitsProcessorOutput( + return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 6cb7d0a7c11..3006e765c88 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,6 +1,4 @@ -import dataclasses import logging -from typing import Union import torch from flashinfer.sampling import ( @@ -11,8 +9,6 @@ ) from vllm.model_executor.custom_op import CustomOp -from sglang.srt.layers.logits_processor import LogitsProcessorOutput - # TODO: move this dict to another place from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -20,71 +16,30 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class SampleOutput: - success: torch.Tensor - probs: torch.Tensor - batch_next_token_ids: torch.Tensor - - class Sampler(CustomOp): def __init__(self): super().__init__() - def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): - # min-token, presence, frequency - if sampling_info.linear_penalties is not None: - logits += sampling_info.linear_penalties - - # repetition - if sampling_info.scaling_penalties is not None: - logits = torch.where( - logits > 0, - logits / sampling_info.scaling_penalties, - logits * sampling_info.scaling_penalties, - ) - - return logits - - def _get_probs( - self, - logits: torch.Tensor, - sampling_info: SamplingBatchInfo, - is_torch_compile: bool = False, - ): + def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): # Post process logits logits = logits.contiguous() logits.div_(sampling_info.temperatures) - if is_torch_compile: - # FIXME: Temporary workaround for unknown bugs in torch.compile - logits.add_(0) - if sampling_info.logit_bias is not None: logits.add_(sampling_info.logit_bias) if sampling_info.vocab_mask is not None: logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) - logits = self._apply_penalties(logits, sampling_info) + logits = sampling_info.penalizer_orchestrator.apply(logits) - return torch.softmax(logits, dim=-1) - - def forward_cuda( - self, - logits: Union[torch.Tensor, LogitsProcessorOutput], - sampling_info: SamplingBatchInfo, - ): - if isinstance(logits, LogitsProcessorOutput): - logits = logits.next_token_logits - - probs = self._get_probs(logits, sampling_info) + probs = torch.softmax(logits, dim=-1) if not global_server_args_dict["disable_flashinfer_sampling"]: max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device ) - if sampling_info.need_min_p_sampling: + if sampling_info.min_ps.any(): probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) batch_next_token_ids, success = min_p_sampling_from_probs( @@ -100,23 +55,18 @@ def forward_cuda( probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps ) - return SampleOutput(success, probs, batch_next_token_ids) - - def forward_native( - self, - logits: Union[torch.Tensor, LogitsProcessorOutput], - sampling_info: SamplingBatchInfo, - ): - if isinstance(logits, LogitsProcessorOutput): - logits = logits.next_token_logits - - probs = self._get_probs(logits, sampling_info, is_torch_compile=True) + if not torch.all(success): + logging.warning("Sampling failed, fallback to top_k=1 strategy") + probs = probs.masked_fill(torch.isnan(probs), 0.0) + argmax_ids = torch.argmax(probs, dim=-1) + batch_next_token_ids = torch.where( + success, batch_next_token_ids, argmax_ids + ) - batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( - probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps - ) + return batch_next_token_ids - return SampleOutput(success, probs, batch_next_token_ids) + def forward_native(): + raise NotImplementedError("Native forward is not implemented yet.") def top_k_top_p_min_p_sampling_from_probs_torch( @@ -137,10 +87,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) try: - # FIXME: torch.multiomial does not support num_samples = 1 - sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[ - :, :1 - ] + sampled_index = torch.multinomial(probs_sort, num_samples=1) except RuntimeError as e: logger.warning(f"Sampling error: {e}") batch_next_token_ids = torch.zeros( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 926266a628f..f3af821e4ef 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,7 +17,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional, Union import torch @@ -31,10 +29,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -if TYPE_CHECKING: - from sglang.srt.layers.sampler import SampleOutput - - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -684,17 +678,11 @@ def merge(self, other: "ScheduleBatch"): self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) - def check_sample_results(self, sample_output: SampleOutput): - if not torch.all(sample_output.success): - probs = sample_output.probs - batch_next_token_ids = sample_output.batch_next_token_ids - logging.warning("Sampling failed, fallback to top_k=1 strategy") - probs = probs.masked_fill(torch.isnan(probs), 0.0) - argmax_ids = torch.argmax(probs, dim=-1) - batch_next_token_ids = torch.where( - sample_output.success, batch_next_token_ids, argmax_ids - ) - sample_output.probs = probs - sample_output.batch_next_token_ids = batch_next_token_ids + def sample(self, logits: torch.Tensor): + from sglang.srt.layers.sampler import Sampler + + sampler = Sampler() + + batch_next_token_ids = sampler(logits, self.sampling_info) - return sample_output.batch_next_token_ids + return batch_next_token_ids diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 127f71900ae..65daed43b28 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -31,7 +31,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -505,29 +505,21 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.EXTEND - ) - next_token_ids = batch.check_sample_results(sample_output) + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = batch.sample(output.next_token_logits) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if logits_output.next_token_logprobs is not None: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs[ - torch.arange( - len(next_token_ids), device=next_token_ids.device - ), - next_token_ids, - ].tolist() - ) - logits_output.input_token_logprobs = ( - logits_output.input_token_logprobs.tolist() - ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() + if output.next_token_logprobs is not None: + output.next_token_logprobs = output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, + ].tolist() + output.input_token_logprobs = output.input_token_logprobs.tolist() + output.normalized_prompt_logprobs = ( + output.normalized_prompt_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() @@ -566,14 +558,12 @@ def forward_prefill_batch(self, batch: ScheduleBatch): self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: - self.add_logprob_return_values( - i, req, pt, next_token_ids, logits_output - ) + self.add_logprob_return_values(i, req, pt, next_token_ids, output) pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) - embeddings = logits_output.embeddings.tolist() + output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = output.embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): @@ -601,7 +591,7 @@ def add_logprob_return_values( req: Req, pt: int, next_token_ids: List[int], - output: LogitsProcessorOutput, + output: LogitProcessorOutput, ): if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] @@ -683,17 +673,15 @@ def forward_decode_batch(self, batch: ScheduleBatch): batch.prepare_for_decode() # Forward and sample the next tokens - sample_output, logits_output = self.model_runner.forward( - batch, ForwardMode.DECODE - ) - next_token_ids = batch.check_sample_results(sample_output) + output = self.model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = batch.sample(output.next_token_logits) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if logits_output.next_token_logprobs is not None: - next_token_logprobs = logits_output.next_token_logprobs[ + if output.next_token_logprobs is not None: + next_token_logprobs = output.next_token_logprobs[ torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() @@ -719,7 +707,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): (next_token_logprobs[i], next_token_id) ) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + req.output_top_logprobs.append(output.output_top_logprobs[i]) self.handle_finished_requests(batch) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 40c87af88cf..796db26623f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -26,18 +26,16 @@ from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.logits_processor import ( + LogitProcessorOutput, LogitsMetadata, LogitsProcessor, - LogitsProcessorOutput, ) -from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, update_flashinfer_indices, ) -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -146,10 +144,6 @@ def __init__( self.flashinfer_kv_indices.clone(), ] - # Sampling inputs - vocab_size = model_runner.model_config.vocab_size - self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) - self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] if use_torch_compile: @@ -241,7 +235,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, - sampling_info=self.sampling_info[:bs], batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, @@ -306,35 +299,27 @@ def replay(self, batch: ScheduleBatch): self.flashinfer_handlers[bs], ) - # Sampling inputs - self.sampling_info.inplace_assign(raw_bs, batch.sampling_info) - # Replay torch.cuda.synchronize() self.graphs[bs].replay() torch.cuda.synchronize() - sample_output, logits_output = self.output_buffers[bs] + output = self.output_buffers[bs] # Unpad if bs != raw_bs: - logits_output = LogitsProcessorOutput( - next_token_logits=logits_output.next_token_logits[:raw_bs], + output = LogitProcessorOutput( + next_token_logits=output.next_token_logits[:raw_bs], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, ) - sample_output = SampleOutput( - sample_output.success[:raw_bs], - sample_output.probs[:raw_bs], - sample_output.batch_next_token_ids[:raw_bs], - ) # Extract logprobs if batch.return_logprob: - logits_output.next_token_logprobs = torch.nn.functional.log_softmax( - logits_output.next_token_logits, dim=-1 + output.next_token_logprobs = torch.nn.functional.log_softmax( + output.next_token_logits, dim=-1 ) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) if return_top_logprob: @@ -342,8 +327,8 @@ def replay(self, batch: ScheduleBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=batch.top_logprobs_nums, ) - logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( - logits_output.next_token_logprobs, logits_metadata + output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + output.next_token_logprobs, logits_metadata )[1] - return sample_output, logits_output + return output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e8849962b07..c107b3bc826 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +16,7 @@ """ModelRunner runs the forward passes of the models.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional import numpy as np import torch @@ -28,7 +26,6 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -45,7 +42,6 @@ class InputMetadata: """Store all inforamtion of a forward pass.""" forward_mode: ForwardMode - sampling_info: SamplingBatchInfo batch_size: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor @@ -183,7 +179,6 @@ def from_schedule_batch( ): ret = cls( forward_mode=forward_mode, - sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, @@ -194,8 +189,6 @@ def from_schedule_batch( top_logprobs_nums=batch.top_logprobs_nums, ) - ret.sampling_info.prepare_penalties() - ret.compute_positions(batch) ret.compute_extend_infos(batch) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0dd9f8c201f..abee152d6fd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -21,7 +21,7 @@ import logging import pkgutil from functools import lru_cache -from typing import Optional, Tuple, Type +from typing import Optional, Type import torch import torch.nn as nn @@ -44,8 +44,6 @@ from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, @@ -517,11 +515,7 @@ def init_cuda_graphs(self): @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): - if ( - self.cuda_graph_runner - and self.cuda_graph_runner.can_run(len(batch.reqs)) - and not batch.sampling_info.has_bias() - ): + if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): return self.cuda_graph_runner.replay(batch) input_metadata = InputMetadata.from_schedule_batch( @@ -570,9 +564,7 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata.image_offsets, ) - def forward( - self, batch: ScheduleBatch, forward_mode: ForwardMode - ) -> Tuple[SampleOutput, LogitsProcessorOutput]: + def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: return self.forward_extend_multi_modal(batch) elif forward_mode == ForwardMode.DECODE: diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 1c189eebbc0..0a22f994bb4 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -31,18 +31,20 @@ ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -381,11 +383,17 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index c360106f97c..f6d6f6e1f94 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,7 +64,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -327,7 +326,6 @@ def __init__( self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.model = CohereModel(config, quant_config) @torch.no_grad() @@ -342,11 +340,9 @@ def forward( positions, input_metadata, ) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b3a76b56ae2..39ac4aefa72 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,7 +45,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -383,7 +382,6 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -393,11 +391,9 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index b939602c1ba..59fd1ec7ed8 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,7 +46,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -386,7 +385,6 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -396,11 +394,9 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 15ecf4bb66b..13dd477392e 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -45,7 +45,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -633,7 +632,6 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() def forward( self, @@ -642,11 +640,9 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 61cc5c66ea5..990937f5180 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,7 +37,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -288,7 +287,6 @@ def __init__( self.quant_config = quant_config self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -299,11 +297,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return (sample_output, logits_output) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index fabf86b498e..c6dbc7e5569 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -41,7 +41,6 @@ from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -397,7 +396,6 @@ def __init__( self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -408,11 +406,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index dc828f0142e..94b7f6153cf 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,7 +35,6 @@ from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -262,7 +261,6 @@ def __init__( if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -272,11 +270,9 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 85a89ca3edc..4a0a08bf88b 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -46,7 +46,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -298,7 +297,6 @@ def __init__( self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) @@ -315,11 +313,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index c0e4d19e128..f2947e991b5 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -263,7 +262,6 @@ def __init__( self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -274,11 +272,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.output.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 42e96123035..9de8d33c5c1 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -39,9 +39,8 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -303,7 +302,6 @@ def __init__( self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -312,13 +310,11 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, - ) -> LogitsProcessorOutput: + ) -> LogitProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def get_module_name(self, name): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index fdf6d28e556..02224971d6a 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitProcessorOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama2 import LlamaModel @@ -65,7 +65,7 @@ def forward( (input_metadata.batch_size, self.config.classification_out_size) ).to(input_ids.device) - return LogitsProcessorOutput( + return LogitProcessorOutput( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 0028ae67a8c..49ff1926f39 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,7 +39,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -298,7 +297,6 @@ def __init__( self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -316,11 +314,9 @@ def forward( lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index ca38cb03bae..d11f6c95198 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -41,7 +41,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -300,7 +299,6 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() def forward( self, @@ -310,11 +308,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 97ac09ee629..b02e925c5a0 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,7 +45,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -334,7 +333,6 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -345,11 +343,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 4958a812985..93dae9585c3 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,7 +39,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -252,7 +251,6 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -262,11 +260,10 @@ def forward( input_metadata: InputMetadata, ): hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + next_tokens = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output + return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 76094b907a7..fcf083e1b5d 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -38,9 +38,8 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -277,7 +276,6 @@ def __init__( self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() @@ -291,11 +289,9 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) if not get_embedding: - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output else: return self.pooler(hidden_states, input_metadata) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index e08695bc61a..9bdbd750660 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -35,8 +35,10 @@ ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -47,7 +49,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -365,7 +366,6 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -376,11 +376,20 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output + + def compute_logits( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + logits = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + return logits def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index a3102baabd4..9e10f12f2a2 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -250,7 +249,6 @@ def __init__( self.model = StableLMEpochModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -261,11 +259,9 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 7843f4bd32d..bc70a9018ed 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -21,63 +21,10 @@ class SamplingBatchInfo: top_ps: torch.Tensor = None top_ks: torch.Tensor = None min_ps: torch.Tensor = None - - # Dispatch in CUDA graph - need_min_p_sampling: bool = False - - # Bias Tensors + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None logit_bias: torch.Tensor = None vocab_mask: torch.Tensor = None - # Penalizer - penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None - linear_penalties: torch.Tensor = None - scaling_penalties: torch.Tensor = None - - def has_bias(self): - return ( - self.logit_bias is not None - or self.vocab_mask is not None - or self.linear_penalties is not None - or self.scaling_penalties is not None - ) - - @classmethod - def dummy_one(cls, max_bs: int, vocab_size: int): - ret = cls(vocab_size=vocab_size) - ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda") - ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda") - ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda") - ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda") - return ret - - def __getitem__(self, key): - if isinstance(key, slice): - # NOTE: We do not use cuda graph when there is bias tensors - assert not self.has_bias() - return SamplingBatchInfo( - vocab_size=self.vocab_size, - temperatures=self.temperatures[key], - top_ps=self.top_ps[key], - top_ks=self.top_ks[key], - min_ps=self.min_ps[key], - need_min_p_sampling=self.need_min_p_sampling, - ) - else: - raise NotImplementedError - - def inplace_assign(self, bs: int, other: SamplingBatchInfo): - # NOTE: We do not use cuda graph when there is bias tensors - assert not self.has_bias() - - self.vocab_size = other.vocab_size - self.need_min_p_sampling = other.need_min_p_sampling - - self.temperatures[:bs] = other.temperatures - self.top_ps[:bs] = other.top_ps - self.top_ks[:bs] = other.top_ks - self.min_ps[:bs] = other.min_ps - @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): device = "cuda" @@ -98,7 +45,6 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): ret.min_ps = torch.tensor( [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device ) - ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) # Each penalizers will do nothing if they evaluate themselves as not required by looking at # the sampling_params of the requests (See {_is_required()} of each penalizers). So this @@ -126,25 +72,6 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): return ret - def prepare_penalties(self): - self.scaling_penalties = None - self.linear_penalties = None - - for penalizer in self.penalizer_orchestrator.penalizers.values(): - if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): - if penalizer.is_prepared(): - self.scaling_penalties = penalizer.cumulated_repetition_penalties - else: - if penalizer.is_prepared(): - if self.linear_penalties is None: - bs = self.penalizer_orchestrator.batch.batch_size() - self.linear_penalties = torch.zeros( - (bs, self.vocab_size), - dtype=torch.float32, - device="cuda", - ) - self.linear_penalties = penalizer.apply(self.linear_penalties) - def update_regex_vocab_mask(self, batch: ScheduleBatch): bs, reqs = batch.batch_size(), batch.reqs device = "cuda" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 2d3b0aefa33..37ed2cf9adc 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -180,7 +180,7 @@ def __init__( tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, - mem_fraction_static=0.69, + mem_fraction_static=0.7, trust_remote_code=False, is_embedding=not self.is_generation, ) diff --git a/python/sglang/version.py b/python/sglang/version.py index f3291e93b7d..839b265519b 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.2.14" +__version__ = "0.2.14.post1"