Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
update zhipu api due to new model and api; repair extra invalid gener…
Browse files Browse the repository at this point in the history
…ate output; update its unittest
  • Loading branch information
better629 committed Jan 17, 2024
1 parent 75cbf9f commit 4e13eac
Show file tree
Hide file tree
Showing 17 changed files with 157 additions and 215 deletions.
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ TIMEOUT: 60 # Timeout for llm invocation

#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
# ZHIPUAI_API_MODEL: "glm-4"

#### if Google Gemini from `https://ai.google.dev/` and API_KEY from `https://makersuite.google.com/app/apikey`.
#### You can set here or export GOOGLE_API_KEY="YOUR_API_KEY"
Expand Down
4 changes: 4 additions & 0 deletions examples/llm_hello_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ async def main():
# streaming mode, much slower
await llm.acompletion_text(hello_msg, stream=True)

# check completion if exist to test llm complete functions
if hasattr(llm, "completion"):
logger.info(llm.completion(hello_msg))


if __name__ == "__main__":
asyncio.run(main())
6 changes: 4 additions & 2 deletions metagpt/actions/write_code_review.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ async def run(self, *args, **kwargs) -> CodingContext:
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
f"{len(self.context.code_doc.content)=}"
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.context.code_doc.content)={len2}"
)
result, rewrited_code = await self.write_code_review_and_rewrite(
context_prompt, cr_prompt, self.context.code_doc.filename
Expand Down
1 change: 1 addition & 0 deletions metagpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def _update(self):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.zhipuai_api_model = self._get("ZHIPUAI_API_MODEL")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
Expand Down
4 changes: 4 additions & 0 deletions metagpt/provider/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def get_choice_text(self, rsp: dict) -> str:
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]

    def get_choice_delta_text(self, rsp: dict) -> str:
        """Required to provide the first text of stream choice.

        Reads `choices[0]["delta"]["content"]` from one streamed chunk of an
        OpenAI-style chat-completion response dict.
        """
        return rsp.get("choices")[0]["delta"]["content"]

def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice
    :param dict rsp: OpenAI chat.completion response JSON, Note "message" must include "tool_calls",
Expand Down
6 changes: 2 additions & 4 deletions metagpt/provider/general_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ def _interpret_response(
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")
):
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)
Expand Down
92 changes: 24 additions & 68 deletions metagpt/provider/zhipuai/async_sse_client.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to keep Event-style access to the response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
# refs to `zhipuai/core/_sse_client.py`

from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
import json
from typing import Any, Iterator


class AsyncSSEClient(object):
    """Minimal async reader for zhipuai's Server-Sent-Events responses.

    Replaces the old `zhipuai.utils.sse_client.SSEClient` subclass: the v2
    zhipuai API streams `data: {...}` lines, so this client only needs to
    strip the SSE `data:` field and JSON-decode each payload.
    """

    def __init__(self, event_source: Iterator[Any]):
        """Wrap the raw event source.

        :param event_source: async iterator yielding one SSE line (bytes) per
            chunk; on a failed request the requestor hands back plain `bytes`
            (the error body) instead of an iterator.
        """
        self._event_source = event_source

    async def stream(self) -> dict:
        """Yield each decoded `data:` payload as a dict until `[DONE]`.

        Raises:
            RuntimeError: when the request itself failed (the event source is
                then the raw error body as bytes rather than an iterator).
        """
        if isinstance(self._event_source, bytes):
            raise RuntimeError(
                f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
            )
        async for chunk in self._event_source:
            line = chunk.decode("utf-8")
            # SSE comment lines start with ":"; an empty line ends the stream.
            # NOTE(review): `return` here assumes the requestor never emits
            # blank keep-alive separators mid-stream — confirm against
            # GeneralAPIRequestor's line splitting.
            if line.startswith(":") or not line:
                return
            field, _, value = line.partition(":")
            # Per the SSE spec, a single leading space in the value is stripped.
            if value.startswith(" "):
                value = value[1:]
            if field == "data":
                if value.startswith("[DONE]"):
                    break
                yield json.loads(value)
59 changes: 19 additions & 40 deletions metagpt/provider/zhipuai/zhipu_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,72 +4,51 @@

import json

import zhipuai
from zhipuai.model_api.api import InvokeType, ModelAPI
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from zhipuai import ZhipuAI
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT

from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient


class ZhiPuModelAPI(ZhipuAI):
    """Async adapter over the zhipuai v2 SDK client.

    Issues the HTTP calls itself through `GeneralAPIRequestor` (openai-sdk
    style) so MetaGPT is insulated from zhipuai API/SDK version changes.
    """

    def split_zhipu_api_url(self):
        """Split the chat-completions endpoint into (base_url, path).

        example:
            zhipu_api_url: https://open.bigmodel.cn/api/paas/v4/chat/completions
        """
        # Keep the URL hard-coded here to prevent a zhipu api upgrade to a
        # different version from silently changing behavior, and to follow
        # the GeneralAPIRequestor implemented based on the openai sdk.
        zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
        arr = zhipu_api_url.split("/api/")
        # ("https://open.bigmodel.cn/api", "/paas/v4/chat/completions")
        return f"{arr[0]}/api", f"/{arr[1]}"

    async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
        """Low-level async HTTP request.

        :param stream: True for SSE streaming responses.
        :param method: "post" or "get".
        :param headers: HTTP headers (the SDK's default auth headers).
        :param kwargs: request payload forwarded as params.
        :return: raw response bytes, or an async line iterator when streaming.
        """
        # TODO: make the async request more generic for models in http mode.
        assert method in ["post", "get"]

        base_url, url = self.split_zhipu_api_url()
        requester = GeneralAPIRequestor(base_url=base_url)
        # arequest returns (result, got_stream, api_key); only result is used.
        result, _, _ = await requester.arequest(
            method=method,
            url=url,
            headers=headers,
            stream=stream,
            params=kwargs,
            request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
        )
        return result

    async def acreate(self, **kwargs) -> dict:
        """Async non-streaming invoke; returns the decoded completion dict.

        Different from the raw SDK method `async_invoke`, which gets the final
        result by task_id.

        Raises:
            RuntimeError: when the response body carries an "error" field.
        """
        headers = self._default_headers
        resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
        resp = json.loads(resp.decode("utf-8"))
        if "error" in resp:
            raise RuntimeError(
                f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
            )
        return resp

    async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
        """Async sse invoke; returns an AsyncSSEClient over the streaming response."""
        headers = self._default_headers
        return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))
65 changes: 20 additions & 45 deletions metagpt/provider/zhipuai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk

import json
from enum import Enum

import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
after_log,
Expand All @@ -15,6 +13,7 @@
stop_after_attempt,
wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion

from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
Expand All @@ -35,26 +34,25 @@ class ZhiPuEvent(Enum):
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
From now, support glm-3-turbo、glm-4, and also system_prompt.
"""

def __init__(self):
self.__init_zhipuai(CONFIG)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.llm = ZhiPuModelAPI(api_key=self.api_key)

def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
self.api_key = config.zhipuai_api_key
self.model = config.zhipuai_api_model # so far, it support glm-3-turbo、glm-4
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.openai_proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.openai_proxy

def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs

def _update_costs(self, usage: dict):
Expand All @@ -67,57 +65,34 @@ def _update_costs(self, usage: dict):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")

def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")

def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp
return resp.model_dump()

async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp

    async def acompletion(self, messages: list[dict], timeout=3) -> dict:
        """Async chat completion entry point; delegates to `_achat_completion`."""
        return await self._achat_completion(messages, timeout=timeout)

async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for event in response.async_events():
if event.event == ZhiPuEvent.ADD.value:
content = event.data
async for chunk in response.stream():
finish_reason = chunk.get("choices")[0].get("finish_reason")
if finish_reason == "stop":
usage = chunk.get("usage", {})
else:
content = self.get_choice_delta_text(chunk)
collected_content.append(content)
log_llm_stream(content)
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")
elif event.event == ZhiPuEvent.FINISH.value:
"""
event.meta
{
"task_status":"SUCCESS",
"usage":{
"completion_tokens":351,
"prompt_tokens":595,
"total_tokens":946
},
"task_id":"xx",
"request_id":"xxx"
}
"""
meta = json.loads(event.meta)
usage = meta.get("usage")
else:
print(f"zhipuapi else event: {event.data}", end="")

log_llm_stream("\n")

self._update_costs(usage)
Expand Down
1 change: 1 addition & 0 deletions metagpt/utils/file_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def save(self, filename: Path | str, content, dependencies: List[str] = No
"""
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
logger.info(f"save to: {str(pathname)}")
Expand Down
Loading

0 comments on commit 4e13eac

Please sign in to comment.