Skip to content

Commit

Permalink
Fix the overhead due to penalizer in bench_latency (#1496)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Sep 23, 2024
1 parent 42a2d82 commit 2854a5e
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 16 deletions.
4 changes: 2 additions & 2 deletions python/sglang/bench_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def correctness_test(

# Decode
output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
for _ in range(bench_args.output_len[0]):
for _ in range(bench_args.output_len[0] - 1):
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
for i in range(len(reqs)):
output_ids[i].append(next_token_ids[i])
Expand Down Expand Up @@ -311,7 +311,7 @@ def latency_test_run_once(

# Decode
decode_latencies = []
for i in range(output_len):
for i in range(output_len - 1):
torch.cuda.synchronize()
tic = time.time()
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
Expand Down
8 changes: 3 additions & 5 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def alloc_token_slots(self, num_tokens: int):
def prepare_for_extend(self, vocab_size: int):
self.forward_mode = ForwardMode.EXTEND

bs = self.batch_size()
bs = len(self.reqs)
reqs = self.reqs
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
extend_num_tokens = sum(len(ids) for ids in input_ids)
Expand Down Expand Up @@ -509,7 +509,7 @@ def mix_with_running(self, running_batch: "ScheduleBatch"):
self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

def check_decode_mem(self):
bs = self.batch_size()
bs = len(self.reqs)
if self.token_to_kv_pool.available_size() >= bs:
return True

Expand Down Expand Up @@ -680,14 +680,12 @@ def prepare_for_decode(self, input_ids=None):
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
for r in self.reqs
]
else:
self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)

self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
self.seq_lens.add_(1)

# Alloc mem
bs = self.batch_size()
bs = len(self.reqs)
self.out_cache_loc = self.alloc_token_slots(bs)

self.req_to_token_pool.req_to_token[
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.do_not_get_new_batch = False

@torch.inference_mode()
def exposed_step(self, recv_reqs: List):
try:
# Recv requests
Expand Down Expand Up @@ -246,7 +247,6 @@ def exposed_step(self, recv_reqs: List):
self.out_pyobjs = []
return ret

@torch.inference_mode()
def forward_step(self):
if self.do_not_get_new_batch and self.current_inflight_req is None:
new_batch = None
Expand Down
6 changes: 2 additions & 4 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,12 @@ def init_multimuldal_info(self, batch: ScheduleBatch):
self.modalities = [r.modalities for r in reqs]

def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets

if self.forward_mode.is_decode():
if True:
self.positions = self.seq_lens - 1
else:
# Deprecated
self.positions = (self.seq_lens - 1) + position_ids_offsets
self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
else:
if True:
self.positions = torch.tensor(
Expand All @@ -119,7 +117,7 @@ def compute_positions(self, batch: ScheduleBatch):
)
else:
# Deprecated
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
self.positions = torch.tensor(
np.concatenate(
[
Expand Down
3 changes: 0 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ def init_cuda_graphs(self):
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)

@torch.inference_mode()
def forward_decode(self, batch: ScheduleBatch):
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(batch)
Expand All @@ -481,7 +480,6 @@ def forward_decode(self, batch: ScheduleBatch):
batch.input_ids, input_metadata.positions, input_metadata
)

@torch.inference_mode()
def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(self, batch)
if self.server_args.lora_paths is not None:
Expand All @@ -500,7 +498,6 @@ def forward_extend(self, batch: ScheduleBatch):
get_embedding=True,
)

@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
Expand Down
2 changes: 1 addition & 1 deletion scripts/playground/reference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def normal_text(args):
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
max_new_tokens = 17
max_new_tokens = 16

torch.cuda.set_device(0)

Expand Down

0 comments on commit 2854a5e

Please sign in to comment.