Skip to content

Commit

Permalink
Fix logprob in the overlapped mode (#1795)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 25, 2024
1 parent c555ce2 commit e646c59
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 29 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pip install "sglang[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
```

**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

### Method 2: From source
```
Expand All @@ -75,7 +75,7 @@ pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
```

**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

### Method 3: Using docker
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
Expand Down
4 changes: 2 additions & 2 deletions docs/en/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pip install "sglang[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
```

**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

### Method 2: From source
```
Expand All @@ -26,7 +26,7 @@ pip install -e "python[all]"
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
```

**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

### Method 3: Using docker
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ class LogitsProcessorOutput:
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# The logprobs of the next tokens. shape: [#seq, vocab_size]
next_token_logprobs: torch.Tensor
next_token_logprobs: torch.Tensor = None

# The normlaized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
normalized_prompt_logprobs: torch.Tensor = None
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor
input_token_logprobs: torch.Tensor = None

# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List
input_top_logprobs: List = None
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List
output_top_logprobs: List = None


@dataclasses.dataclass
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
else:
# Move next_token_ids and logprobs to cpu
if batch.return_logprob:
Expand Down
41 changes: 38 additions & 3 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def forward_thread_func_(self):
while True:
self.has_inflight_batch = False
model_worker_batch, future_token_ids_ct = self.input_queue.get()
if not model_worker_batch:
break
self.has_inflight_batch = True
self.launch_event = threading.Event()

Expand All @@ -122,19 +124,48 @@ def forward_thread_func_(self):
] = next_token_ids

# Copy results to the CPU
if model_worker_batch.return_logprob:
logits_output.next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=self.device),
next_token_ids,
].to("cpu", non_blocking=True)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.to(
"cpu", non_blocking=True
)
)
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()

self.launch_event.set()
self.copy_queue.put((copy_event, next_token_ids))
self.copy_queue.put((copy_event, logits_output, next_token_ids))

def copy_thread_func(self):
while True:
copy_event, next_token_ids = self.copy_queue.get()
copy_event, logits_output, next_token_ids = self.copy_queue.get()
if not copy_event:
break
while not copy_event.query():
time.sleep(1e-5)
self.output_queue.put((None, next_token_ids.tolist()))

if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = (
logits_output.input_token_logprobs.tolist()
)
logits_output.normalized_prompt_logprobs = (
logits_output.normalized_prompt_logprobs.tolist()
)

self.output_queue.put((logits_output, next_token_ids.tolist()))

def resulve_batch_result(self, bid: int):
logits_output, next_token_ids = self.output_queue.get()
Expand Down Expand Up @@ -172,3 +203,7 @@ def update_weights(self, recv_req: UpdateWeightReqInput):
recv_req.model_path, recv_req.load_format
)
return success, message

def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))
30 changes: 14 additions & 16 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def run_once():
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
)
return forward(input_ids, forward_batch.positions, forward_batch)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits

for _ in range(2):
torch.cuda.synchronize()
Expand Down Expand Up @@ -318,23 +319,16 @@ def replay(self, forward_batch: ForwardBatch):

# Replay
self.graphs[bs].replay()
logits_output = self.output_buffers[bs]

# Unpad
if bs != raw_bs:
logits_output = LogitsProcessorOutput(
next_token_logits=logits_output.next_token_logits[:raw_bs],
next_token_logprobs=None,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
next_token_logits = self.output_buffers[bs][:raw_bs]

# Extract logprobs
if forward_batch.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logprobs=next_token_logprobs,
)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
Expand All @@ -343,7 +337,11 @@ def replay(self, forward_batch: ForwardBatch):
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
next_token_logprobs, logits_metadata
)[1]
else:
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
)

return logits_output
1 change: 0 additions & 1 deletion test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"test_openai_server.py",
"test_overlap_schedule.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
Expand Down

0 comments on commit e646c59

Please sign in to comment.