[FEAT] Support batches cancel (#1222)

caiyueliang and zhyncs authored Aug 26, 2024
Co-authored-by: Yineng Zhang <[email protected]>
1 parent c61a1b6 commit 2f1d928

Showing 3 changed files with 122 additions and 6 deletions.
87 changes: 83 additions & 4 deletions python/sglang/srt/openai_api/adapter.py
@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
end_point = batch_storage[batch_id].endpoint
file_request_list = []
all_requests = []
request_ids = []
for line in lines:
request_data = json.loads(line)
file_request_list.append(request_data)
body = request_data["body"]
request_ids.append(request_data["custom_id"])

# Although streaming is supported for standalone completions, it is not supported in
# batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
all_requests.append(ChatCompletionRequest(**body))
elif end_point == "/v1/completions":
all_requests.append(CompletionRequest(**body))

if end_point == "/v1/chat/completions":
adapted_request, request = v1_chat_generate_request(
all_requests, tokenizer_manager
all_requests, tokenizer_manager, request_ids=request_ids
)
elif end_point == "/v1/completions":
adapted_request, request = v1_generate_request(all_requests)
adapted_request, request = v1_generate_request(
all_requests, request_ids=request_ids
)

try:
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
if not isinstance(ret, list):
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
}
all_ret.append(response_json)
completed_requests += 1

# Write results to a new file
output_file_id = f"backend_result_file-{uuid.uuid4()}"
global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
return batch_response


async def v1_cancel_batch(tokenizer_manager, batch_id: str):
# Retrieve the batch job from the in-memory storage
batch_response = batch_storage.get(batch_id)
if batch_response is None:
raise HTTPException(status_code=404, detail="Batch not found")

# Only cancel when status is "validating" or "in_progress"
if batch_response.status in ["validating", "in_progress"]:
# Start cancelling the batch asynchronously
asyncio.create_task(
cancel_batch(
tokenizer_manager=tokenizer_manager,
batch_id=batch_id,
input_file_id=batch_response.input_file_id,
)
)

# Update batch status to "cancelling"
batch_response.status = "cancelling"

return batch_response
else:
raise HTTPException(
status_code=500,
detail=f"Current status is {batch_response.status}, no need to cancel",
)


async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
try:
# Update the batch status to "cancelling"
batch_storage[batch_id].status = "cancelling"

# Retrieve the input file content
input_file_request = file_id_request.get(input_file_id)
if not input_file_request:
raise ValueError("Input file not found")

# Parse the JSONL file and process each request
input_file_path = file_id_storage.get(input_file_id)
with open(input_file_path, "r", encoding="utf-8") as f:
lines = f.readlines()

file_request_list = []
request_ids = []
for line in lines:
request_data = json.loads(line)
file_request_list.append(request_data)
request_ids.append(request_data["custom_id"])

# Cancel requests by request_ids
for rid in request_ids:
tokenizer_manager.abort_request(rid=rid)

retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "cancelled"

except Exception as e:
logger.error(f"error in SGLang: {e}")
# Update batch status to "failed"
retrieve_batch = batch_storage[batch_id]
retrieve_batch.status = "failed"
retrieve_batch.failed_at = int(time.time())
retrieve_batch.errors = {"message": str(e)}


async def v1_retrieve_file(file_id: str):
# Retrieve the batch job from the in-memory storage
file_response = file_id_response.get(file_id)
@@ -392,7 +465,9 @@ def iter_file():
return StreamingResponse(iter_file(), media_type="application/octet-stream")


def v1_generate_request(all_requests: List[CompletionRequest]):
def v1_generate_request(
all_requests: List[CompletionRequest], request_ids: List[str] = None
):
prompts = []
sampling_params_list = []
return_logprobs = []
@@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
logprob_start_len=logprob_start_lens,
return_text_in_logprobs=True,
stream=all_requests[0].stream,
rid=request_ids,
)

if len(all_requests) == 1:
@@ -746,7 +822,9 @@ async def generate_stream_resp():


def v1_chat_generate_request(
all_requests: List[ChatCompletionRequest], tokenizer_manager
all_requests: List[ChatCompletionRequest],
tokenizer_manager,
request_ids: List[str] = None,
):
input_ids = []
sampling_params_list = []
@@ -834,6 +912,7 @@ def v1_chat_generate_request(
top_logprobs_num=top_logprobs_nums,
stream=all_requests[0].stream,
return_text_in_logprobs=True,
rid=request_ids,
)
if len(all_requests) == 1:
return adapted_request, all_requests[0]
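
For context, a minimal sketch of the batch input JSONL that this adapter parses; each line's custom_id is what the new cancel path reuses as the request id (rid) handed to tokenizer_manager.abort_request. The file name and model below are illustrative assumptions, not part of the change.

import json

# Illustrative batch input; every "custom_id" doubles as the rid used for cancellation.
batch_lines = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model name
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 32,
        },
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "meta-llama/Meta-Llama-3-8B-Instruct",
            "messages": [{"role": "user", "content": "What is 2 + 2?"}],
            "max_tokens": 32,
        },
    },
]

with open("batch_input.jsonl", "w") as f:
    for line in batch_lines:
        f.write(json.dumps(line) + "\n")
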
7 changes: 7 additions & 0 deletions python/sglang/srt/server.py
@@ -59,6 +59,7 @@
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
v1_batches,
v1_cancel_batch,
v1_chat_completions,
v1_completions,
v1_delete_file,
@@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
return await v1_retrieve_batch(batch_id)
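
A minimal client-side sketch of the new route, mirroring the test changes below; the server address, API key, and input file name are assumptions to adjust for your deployment.

import time

import openai

# Assumed local SGLang server; point api_key/base_url at your own instance.
client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

with open("batch_input.jsonl", "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# POSTs to /v1/batches/{batch_id}/cancel; the server replies with status "cancelling".
batch_job = client.batches.cancel(batch_id=batch_job.id)

# Poll until the background cancel_batch task marks the job "cancelled" (or "failed").
while batch_job.status not in ["failed", "cancelled"]:
    time.sleep(3)
    batch_job = client.batches.retrieve(batch_job.id)
print(batch_job.status)
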
34 changes: 32 additions & 2 deletions test/srt/test_openai_server.py
@@ -256,8 +256,7 @@ def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
index, True
), f"index {index} is not found in the response"

def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
def _create_batch(self, mode, client):
if mode == "completion":
input_file_path = "complete_input.jsonl"
# write content to input file
@@ -333,9 +332,11 @@ def run_batch(self, mode):
},
},
]

with open(input_file_path, "w") as file:
for line in content:
file.write(json.dumps(line) + "\n")

with open(input_file_path, "rb") as file:
uploaded_file = client.files.create(file=file, purpose="batch")
if mode == "completion":
@@ -348,6 +349,13 @@ def run_batch(self, mode):
endpoint=endpoint,
completion_window=completion_window,
)

return batch_job, content

def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, content = self._create_batch(mode=mode, client=client)

while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3)
print(
@@ -371,6 +379,24 @@ def run_batch(self, mode):
]
assert len(results) == len(content)

def run_cancel_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, _ = self._create_batch(mode=mode, client=client)

assert batch_job.status not in ["cancelling", "cancelled"]

batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"

while batch_job.status not in ["failed", "cancelled"]:
batch_job = client.batches.retrieve(batch_job.id)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
time.sleep(3)

assert batch_job.status == "cancelled"

def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -414,6 +440,10 @@ def test_batch(self):
for mode in ["completion", "chat"]:
self.run_batch(mode)

def test_cancel_batch(self):
for mode in ["completion", "chat"]:
self.run_cancel_batch(mode)

def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)

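
The same flow without the SDK, hitting the routes registered in server.py directly; a sketch assuming the server listens on localhost:30000 and batch_id came from an earlier POST /v1/batches.

import requests

BASE = "http://127.0.0.1:30000"  # assumed server address
batch_id = "batch_..."  # placeholder; use the id returned by POST /v1/batches

# Ask the server to cancel; the response is the batch object with status "cancelling".
resp = requests.post(f"{BASE}/v1/batches/{batch_id}/cancel")
print(resp.json()["status"])

# Later, retrieve the batch to confirm the background task reached "cancelled".
resp = requests.get(f"{BASE}/v1/batches/{batch_id}")
print(resp.json()["status"])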
