diff --git a/README.md b/README.md
index 381235b9..4588cbdd 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy.
   - **[Anthropic](https://docs.anthropic.com/claude/reference/)**
   - **[Google PaLM](https://ai.google/discover/palm2/)**
   - **[Microsoft Azure AI](https://azure.microsoft.com/en-us/solutions/ai)**
+  - **[Groq](https://groq.com/)**
 - Supports open-source LLMs hosted on Hugging Face 🤗:
   - **[Falcon](https://huggingface.co/tiiuae)**
   - **[Dolly](https://huggingface.co/databricks)**
diff --git a/spacy_llm/models/__init__.py b/spacy_llm/models/__init__.py
index c1427009..e0fe85a3 100644
--- a/spacy_llm/models/__init__.py
+++ b/spacy_llm/models/__init__.py
@@ -1,12 +1,13 @@
 from .hf import dolly_hf, openllama_hf, stablelm_hf
 from .langchain import query_langchain
-from .rest import anthropic, cohere, noop, openai, palm
+from .rest import anthropic, cohere, groq, noop, openai, palm
 
 __all__ = [
     "anthropic",
     "cohere",
     "openai",
     "dolly_hf",
+    "groq",
     "noop",
     "stablelm_hf",
     "openllama_hf",
diff --git a/spacy_llm/models/rest/__init__.py b/spacy_llm/models/rest/__init__.py
index 96263967..c6dee8ce 100644
--- a/spacy_llm/models/rest/__init__.py
+++ b/spacy_llm/models/rest/__init__.py
@@ -1,10 +1,11 @@
-from . import anthropic, azure, base, cohere, noop, openai
+from . import anthropic, azure, base, cohere, groq, noop, openai
 
 __all__ = [
     "anthropic",
     "azure",
     "base",
     "cohere",
+    "groq",
     "openai",
     "noop",
 ]
diff --git a/spacy_llm/models/rest/groq/__init__.py b/spacy_llm/models/rest/groq/__init__.py
new file mode 100644
index 00000000..2f816909
--- /dev/null
+++ b/spacy_llm/models/rest/groq/__init__.py
@@ -0,0 +1,4 @@
+from .model import Endpoints, Groq
+from .registry import groq
+
+__all__ = ["Groq", "Endpoints", "groq"]
diff --git a/spacy_llm/models/rest/groq/model.py b/spacy_llm/models/rest/groq/model.py
new file mode 100644
index 00000000..b443805f
--- /dev/null
+++ b/spacy_llm/models/rest/groq/model.py
@@ -0,0 +1,164 @@
+import os
+import warnings
+from enum import Enum
+from typing import Any, Dict, Iterable, List, Sized
+
+import requests  # type: ignore[import]
+import srsly  # type: ignore[import]
+from requests import HTTPError
+
+from ..base import REST
+
+
+class Endpoints(str, Enum):
+    CHAT = "https://api.groq.com/openai/v1/chat/completions"
+    NON_CHAT = CHAT  # Completion endpoints are not available
+
+
+class Groq(REST):
+    @property
+    def credentials(self) -> Dict[str, str]:
+        # Fetch and check the key.
+        api_key = os.getenv("GROQ_API_KEY")
+
+        if api_key is None:
+            warnings.warn(
+                "Could not find the API key to access the Groq API. Ensure you have an API key "
+                "set up via https://console.groq.com/keys, then make it available as "
+                "an environment variable 'GROQ_API_KEY'."
+            )
+
+        headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+
+        return headers
+
+    # Check the access and get a list of available models to verify the model
+    # argument (if not None). Even if the model is None, this call doubles as a
+    # healthcheck to verify API access.
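+    # The model listing is assumed to follow the OpenAI-compatible shape
+    # (abridged): {"object": "list", "data": [{"id": "llama-3.1-8b-instant", ...}]};
+    # hence r.json()["data"] below, collecting the "id" field of each entry.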
+    def _verify_auth(self) -> None:
+        r = self.retry(
+            call_method=requests.get,
+            url="https://api.groq.com/openai/v1/models",
+            headers=self._credentials,
+            timeout=self._max_request_time,
+        )
+        if r.status_code == 422:
+            warnings.warn(
+                "Could not access api.groq.com (status 422). Visit "
+                "https://console.groq.com/keys to check your API keys."
+            )
+        elif r.status_code != 200:
+            if "Incorrect API key" in r.text:
+                warnings.warn(
+                    "Authentication with the provided API key failed. Please double-check that you provided the "
+                    "correct credentials."
+                )
+            else:
+                warnings.warn(
+                    f"Error accessing api.groq.com ({r.status_code}): {r.text}"
+                )
+
+        models = [model["id"] for model in r.json()["data"]]
+        if self._name not in models:
+            raise ValueError(
+                f"The specified model '{self._name}' is not available. Choices are: {sorted(set(models))}"
+            )
+
+    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:
+        headers = {
+            **self._credentials,
+            "Content-Type": "application/json",
+        }
+        all_api_responses: List[List[str]] = []
+
+        for prompts_for_doc in prompts:
+            api_responses: List[str] = []
+            prompts_for_doc = list(prompts_for_doc)
+
+            def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
+                r = self.retry(
+                    call_method=requests.post,
+                    url=self._endpoint,
+                    headers=headers,
+                    json={**json_data, **self._config, "model": self._name},
+                    timeout=self._max_request_time,
+                )
+                try:
+                    r.raise_for_status()
+                except HTTPError as ex:
+                    res_content = srsly.json_loads(r.content.decode("utf-8"))
+                    # Include the specific error message in the exception.
+                    raise ValueError(
+                        f"Request to Groq API failed: {res_content.get('error', {}).get('message', str(res_content))}"
+                    ) from ex
+                responses = r.json()
+
+                if "error" in responses:
+                    if self._strict:
+                        raise ValueError(f"API call failed: {responses}.")
+                    else:
+                        assert isinstance(prompts_for_doc, Sized)
+                        return {
+                            "error": [srsly.json_dumps(responses)]
+                            * len(prompts_for_doc)
+                        }
+
+                return responses
+
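+            # Each request body merges the per-prompt messages with the user
+            # config and the model name, e.g. (illustrative, with the registry
+            # defaults):
+            #   {"messages": [{"role": "user", "content": "..."}],
+            #    "temperature": 0.0, "model": "llama-3.1-70b-versatile"}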
+            # Groq only exposes an OpenAI-style chat endpoint -- there is no
+            # separate completions endpoint (yet) -- so we send one request per
+            # prompt.
+            for prompt in prompts_for_doc:
+                responses = _request(
+                    {"messages": [{"role": "user", "content": prompt}]}
+                )
+                if "error" in responses:
+                    return responses["error"]
+
+                # Process responses.
+                assert len(responses["choices"]) == 1
+                response = responses["choices"][0]
+                api_responses.append(
+                    response.get("message", {}).get(
+                        "content", srsly.json_dumps(response)
+                    )
+                )
+
+            all_api_responses.append(api_responses)
+
+        return all_api_responses
+
+    @staticmethod
+    def _get_context_lengths() -> Dict[str, int]:
+        return {
+            "gemma2-9b-it": 8192,
+            "gemma-7b-it": 8192,
+            "llama-3.1-70b-versatile": 131072,
+            "llama-3.1-8b-instant": 131072,
+            "llama3-70b-8192": 8192,
+            "llama3-8b-8192": 8192,
+            "llama3-groq-70b-8192-tool-use-preview": 8192,
+            "llama3-groq-8b-8192-tool-use-preview": 8192,
+            "llama-guard-3-8b": 8192,
+            "mixtral-8x7b-32768": 32768,
+            "whisper-large-v3": 1500,
+        }
diff --git a/spacy_llm/models/rest/groq/registry.py b/spacy_llm/models/rest/groq/registry.py
new file mode 100644
index 00000000..c170f941
--- /dev/null
+++ b/spacy_llm/models/rest/groq/registry.py
@@ -0,0 +1,56 @@
+from typing import Any, Dict, Optional
+
+from confection import SimpleFrozenDict
+
+from ....registry import registry
+from .model import Endpoints, Groq
+
+_DEFAULT_TEMPERATURE = 0.0
+
+"""
+Parameter explanations:
+    strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
+        or other response object that does not conform to what a well-formed response from this API should look
+        like). If False, the API error responses are returned by __call__(), but no error will be raised.
+    max_tries (int): Max. number of tries for API request.
+    interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff at
+        each retry.
+    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
+    endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
+"""
+
+
+@registry.llm_models("spacy.groq.v1")
+def groq(
+    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
+    name: str = "llama-3.1-70b-versatile",
+    strict: bool = Groq.DEFAULT_STRICT,
+    max_tries: int = Groq.DEFAULT_MAX_TRIES,
+    interval: float = Groq.DEFAULT_INTERVAL,
+    max_request_time: float = Groq.DEFAULT_MAX_REQUEST_TIME,
+    endpoint: Optional[str] = None,
+    context_length: Optional[int] = None,
+) -> Groq:
+    """Returns Groq instance for 'llama-3.1-70b-versatile' model using REST to prompt API.
+
+    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
+    name (str): Model name to use. Can be any model name supported by the Groq API, e.g. 'llama-3.1-70b-versatile'
+        or 'llama-3.1-8b-instant'.
+    context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context
+        length is natively provided by spacy-llm.
+    RETURNS (Groq): Groq instance for 'llama-3.1-70b-versatile' model.
+
+    DOCS: https://spacy.io/api/large-language-models#models
+    """
+    return Groq(
+        name=name,
+        endpoint=endpoint or Endpoints.CHAT.value,
+        config=config,
+        strict=strict,
+        max_tries=max_tries,
+        interval=interval,
+        max_request_time=max_request_time,
+        context_length=context_length,
+    )
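+
+# Example usage (a sketch; assumes GROQ_API_KEY is set in the environment):
+#   from spacy_llm.models.rest.groq import groq
+#   model = groq(name="llama-3.1-8b-instant")
+#   print(model([["What is spaCy?"]]))  # One list of responses per doc.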
diff --git a/usage_examples/ner_groq/README.md b/usage_examples/ner_groq/README.md
new file mode 100644
index 00000000..2f967f7e
--- /dev/null
+++ b/usage_examples/ner_groq/README.md
@@ -0,0 +1,75 @@
+# Using Models from Groq for Named Entity Recognition (NER)
+
+This example shows how you can use a model hosted by Groq for Named Entity Recognition (NER).
+The NER prompt is based on the [PromptNER](https://arxiv.org/abs/2305.15444) paper and
+utilizes Chain-of-Thought reasoning to extract named entities.
+
+First, create a new API key from
+[console.groq.com](https://console.groq.com/keys) or fetch an existing
+one. Record the secret key and make sure it is available as an environment
+variable:
+
+```sh
+export GROQ_API_KEY="gsk_..."
+```
+
+Then, you can run the pipeline on a sample text via:
+
+```sh
+python run_pipeline.py [TEXT] [PATH TO CONFIG] [PATH TO FILE WITH EXAMPLES]
+```
+
+For example:
+
+```sh
+python run_pipeline.py \
+    "Sriracha sauce goes really well with hoisin stir fry, but you should add it after you use the wok." \
+    ./fewshot.cfg \
+    ./examples.json
+```
+
+This example assigns labels for DISH, INGREDIENT, and EQUIPMENT.
+
+You can adapt the labels and examples to your use case.
+You can find the few-shot examples in the
+`examples.json` file. Feel free to change and update it to your liking.
+We also support other file formats for these examples, including `yml` and `jsonl`.
+
+### Negative examples
+
+While not required, the Chain-of-Thought reasoning for the `spacy.NER.v3` task
+works best in our experience when both positive and negative examples are provided.
+
+This prompts the language model with concrete examples of what **is not** an entity
+for your use case.
+
+Here's an example that helps define the INGREDIENT label for the LLM:
+
+```json
+[
+  {
+    "text": "You can't get a great chocolate flavor with carob.",
+    "spans": [
+      {
+        "text": "chocolate",
+        "is_entity": false,
+        "label": "==NONE==",
+        "reason": "is a flavor in this context, not an ingredient"
+      },
+      {
+        "text": "carob",
+        "is_entity": true,
+        "label": "INGREDIENT",
+        "reason": "is an ingredient to add chocolate flavor"
+      }
+    ]
+  }
+  ...
+]
+```
+
+In this example, "chocolate" is not an ingredient even though it could be in other contexts.
+We explain that via the "reason" property of this example.
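+
+### Choosing a different Groq model
+
+`spacy.groq.v1` defaults to the `llama-3.1-70b-versatile` model. To try another
+chat model hosted by Groq, override `name` in the model block of your config.
+The snippet below is a sketch; `llama-3.1-8b-instant` is one of the models this
+integration ships context lengths for:
+
+```ini
+[components.llm.model]
+@llm_models = "spacy.groq.v1"
+name = "llama-3.1-8b-instant"
+```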
diff --git a/usage_examples/ner_groq/__init__.py b/usage_examples/ner_groq/__init__.py
new file mode 100644
index 00000000..06fab2f6
--- /dev/null
+++ b/usage_examples/ner_groq/__init__.py
@@ -0,0 +1,3 @@
+from .run_pipeline import run_pipeline
+
+__all__ = ["run_pipeline"]
diff --git a/usage_examples/ner_groq/examples.json b/usage_examples/ner_groq/examples.json
new file mode 100644
index 00000000..f38f5e92
--- /dev/null
+++ b/usage_examples/ner_groq/examples.json
@@ -0,0 +1,36 @@
+[
+  {
+    "text": "You can't get a great chocolate flavor with carob.",
+    "spans": [
+      {
+        "text": "chocolate",
+        "is_entity": false,
+        "label": "==NONE==",
+        "reason": "is a flavor in this context, not an ingredient"
+      },
+      {
+        "text": "carob",
+        "is_entity": true,
+        "label": "INGREDIENT",
+        "reason": "is an ingredient to add chocolate flavor"
+      }
+    ]
+  },
+  {
+    "text": "You can probably sand-blast it if it's an anodized aluminum pan",
+    "spans": [
+      {
+        "text": "sand-blast",
+        "is_entity": false,
+        "label": "==NONE==",
+        "reason": "is a cleaning technique, not some kind of equipment"
+      },
+      {
+        "text": "anodized aluminum pan",
+        "is_entity": true,
+        "label": "EQUIPMENT",
+        "reason": "is a piece of cooking equipment, anodized is included since it describes the type of pan"
+      }
+    ]
+  }
+]
diff --git a/usage_examples/ner_groq/fewshot.cfg b/usage_examples/ner_groq/fewshot.cfg
new file mode 100644
index 00000000..0d7653cd
--- /dev/null
+++ b/usage_examples/ner_groq/fewshot.cfg
@@ -0,0 +1,31 @@
+[paths]
+examples = null
+
+[nlp]
+lang = "en"
+pipeline = ["llm"]
+
+[components]
+
+[components.llm]
+factory = "llm"
+
+[components.llm.task]
+@llm_tasks = "spacy.NER.v3"
+labels = ["DISH", "INGREDIENT", "EQUIPMENT"]
+description = Entities are the names of food dishes,
+    ingredients, and any kind of cooking equipment.
+    Adjectives, verbs, adverbs are not entities.
+    Pronouns are not entities.
+
+[components.llm.task.label_definitions]
+DISH = "Known food dishes, e.g. Lobster Ravioli, garlic bread"
+INGREDIENT = "Individual parts of a food dish, including herbs and spices."
+EQUIPMENT = "Any kind of cooking equipment, e.g. oven, cooking pot, grill"
+
+[components.llm.task.examples]
+@misc = "spacy.FewShotReader.v1"
+path = "${paths.examples}"
+
+[components.llm.model]
+@llm_models = "spacy.groq.v1"
diff --git a/usage_examples/ner_groq/run_pipeline.py b/usage_examples/ner_groq/run_pipeline.py
new file mode 100644
index 00000000..ac182bbc
--- /dev/null
+++ b/usage_examples/ner_groq/run_pipeline.py
@@ -0,0 +1,29 @@
+from pathlib import Path
+
+import typer
+from wasabi import msg
+
+from spacy_llm.util import assemble
+
+Arg = typer.Argument
+Opt = typer.Option
+
+
+def run_pipeline(
+    # fmt: off
+    text: str = Arg("", help="Text to perform Named Entity Recognition on."),
+    config_path: Path = Arg(..., help="Path to the configuration file to use."),
+    examples_path: Path = Arg(..., help="Path to the examples file to use."),
+    verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."),
+    # fmt: on
+):
+    msg.text(f"Loading config from {config_path}", show=verbose)
+    nlp = assemble(config_path, overrides={"paths.examples": str(examples_path)})
+    doc = nlp(text)
+
+    msg.text(f"Text: {doc.text}")
+    msg.text(f"Entities: {[(ent.text, ent.label_) for ent in doc.ents]}")
+
+
+if __name__ == "__main__":
+    typer.run(run_pipeline)
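+
+# Example invocation (assumes GROQ_API_KEY is set; see the README in this folder):
+#   python run_pipeline.py \
+#       "Sriracha sauce goes really well with hoisin stir fry, but you should add it after you use the wok." \
+#       ./fewshot.cfg \
+#       ./examples.json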