Add support for Groq models #480

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -23,6 +23,7 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy.io)
- **[Anthropic](https://docs.anthropic.com/claude/reference/)**
- **[Google PaLM](https://ai.google/discover/palm2/)**
- **[Microsoft Azure AI](https://azure.microsoft.com/en-us/solutions/ai)**
- **[Groq](https://groq.com/)**
- Supports open-source LLMs hosted on Hugging Face 🤗:
- **[Falcon](https://huggingface.co/tiiuae)**
- **[Dolly](https://huggingface.co/databricks)**
3 changes: 2 additions & 1 deletion spacy_llm/models/__init__.py
@@ -1,12 +1,13 @@
from .hf import dolly_hf, openllama_hf, stablelm_hf
from .langchain import query_langchain
from .rest import anthropic, cohere, noop, openai, palm
from .rest import anthropic, cohere, groq, noop, openai, palm

__all__ = [
"anthropic",
"cohere",
"openai",
"dolly_hf",
"groq",
"noop",
"stablelm_hf",
"openllama_hf",
3 changes: 2 additions & 1 deletion spacy_llm/models/rest/__init__.py
@@ -1,10 +1,11 @@
from . import anthropic, azure, base, cohere, noop, openai
from . import anthropic, azure, base, cohere, groq, noop, openai

__all__ = [
"anthropic",
"azure",
"base",
"cohere",
"groq",
"openai",
"noop",
]
4 changes: 4 additions & 0 deletions spacy_llm/models/rest/groq/__init__.py
@@ -0,0 +1,4 @@
from .model import Endpoints, Groq
from .registry import groq

__all__ = ["Groq", "Endpoints", "groq"]
164 changes: 164 additions & 0 deletions spacy_llm/models/rest/groq/model.py
@@ -0,0 +1,164 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
from requests import HTTPError

from ..base import REST


class Endpoints(str, Enum):
CHAT = "https://api.groq.com/openai/v1/chat/completions"
NON_CHAT = CHAT # Completion endpoints are not available


class Groq(REST):
@property
def credentials(self) -> Dict[str, str]:
# Fetch and check the key
api_key = os.getenv("GROQ_API_KEY")

if api_key is None:
warnings.warn(
"Could not find the API key to access the OpenAI API. Ensure you have an API key "
"set up via https://console.groq.com/keys, then make it available as "
"an environment variable 'GROQ_API_KEY'."
)

        # Assemble the authorization headers. Access and model availability are
        # verified separately in _verify_auth().
headers = {
"Authorization": f"Bearer {api_key}",
}

return headers

def _verify_auth(self) -> None:
r = self.retry(
call_method=requests.get,
url="https://api.groq.com/openai/v1/models",
headers=self._credentials,
timeout=self._max_request_time,
)
if r.status_code == 422:
warnings.warn(
"Could not access api.groq.com -- 422 permission denied."
"Visit https://console.groq.com/keys to check your API keys."
)
elif r.status_code != 200:
if "Incorrect API key" in r.text:
warnings.warn(
"Authentication with provided API key failed. Please double-check you provided the correct "
"credentials."
)
else:
warnings.warn(
f"Error accessing api.groq.com({r.status_code}): {r.text}"
)

        response = r.json()["data"]
        models = [model["id"] for model in response]
if self._name not in models:
raise ValueError(
f"The specified model '{self._name}' is not available. Choices are: {sorted(set(models))}"
)

def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:
headers = {
**self._credentials,
"Content-Type": "application/json",
}
all_api_responses: List[List[str]] = []

for prompts_for_doc in prompts:
api_responses: List[str] = []
prompts_for_doc = list(prompts_for_doc)

def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
r = self.retry(
call_method=requests.post,
url=self._endpoint,
headers=headers,
json={**json_data, **self._config, "model": self._name},
timeout=self._max_request_time,
)
try:
r.raise_for_status()
except HTTPError as ex:
res_content = srsly.json_loads(r.content.decode("utf-8"))
# Include specific error message in exception.
raise ValueError(
f"Request to Groq API failed: {res_content.get('error', {}).get('message', str(res_content))}"
) from ex
responses = r.json()

if "error" in responses:
if self._strict:
raise ValueError(f"API call failed: {responses}.")
else:
assert isinstance(prompts_for_doc, Sized)
return {
"error": [srsly.json_dumps(responses)]
* len(prompts_for_doc)
}

return responses

            # The Groq API only exposes the chat endpoint (NON_CHAT aliases CHAT),
            # so each prompt is sent as an individual chat request.
            for prompt in prompts_for_doc:
                responses = _request(
                    {"messages": [{"role": "user", "content": prompt}]}
                )
                if "error" in responses:
                    return responses["error"]

                # Process responses.
                assert len(responses["choices"]) == 1
                response = responses["choices"][0]
                api_responses.append(
                    response.get("message", {}).get(
                        "content", srsly.json_dumps(response)
                    )
                )

all_api_responses.append(api_responses)

return all_api_responses

@staticmethod
def _get_context_lengths() -> Dict[str, int]:
return {
"gemma2-9b-it": 8192,
"gemma-7b-it": 8192,
"llama-3.1-70b-versatile": 131072,
"llama-3.1-8b-instant": 131072,
"llama3-70b-8192": 8192,
"llama3-8b-8192": 8192,
"llama3-groq-70b-8192-tool-use-preview": 8192,
"llama3-groq-8b-8192-tool-use-preview": 8192,
"llama-guard-3-8b": 8192,
"mixtral-8x7b-32768": 32768,
"whisper-large-v3": 1500,
}
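
For context, here is a minimal standalone sketch of the exchange that `_request()` performs above. It assumes `GROQ_API_KEY` is set and uses the same OpenAI-compatible chat-completions schema; the model name and prompt are illustrative:

```python
import os

import requests

headers = {
    "Authorization": f"Bearer {os.environ['GROQ_API_KEY']}",
    "Content-Type": "application/json",
}
# The payload mirrors what __call__ sends: config values plus the model
# name, merged with a single-turn "messages" list.
payload = {
    "model": "llama-3.1-8b-instant",  # illustrative; any model listed above works
    "temperature": 0.0,
    "messages": [{"role": "user", "content": "Say hello in one word."}],
}
r = requests.post(
    "https://api.groq.com/openai/v1/chat/completions",
    headers=headers,
    json=payload,
    timeout=30,
)
r.raise_for_status()
# The reply lives at choices[0].message.content, as parsed in __call__.
print(r.json()["choices"][0]["message"]["content"])
```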
56 changes: 56 additions & 0 deletions spacy_llm/models/rest/groq/registry.py
@@ -0,0 +1,56 @@
from typing import Any, Dict, Optional

from confection import SimpleFrozenDict

from ....compat import Literal
from ....registry import registry
from .model import Endpoints, Groq

_DEFAULT_TEMPERATURE = 0.0

"""
Parameter explanations:
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
    or other response object that does not conform to the expectation of what a well-formed response object from
    this API should look like). If False, the API error responses are returned by __call__(), but no error will
    be raised.
max_tries (int): Max. number of tries for the API request.
interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff
    at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
"""


@registry.llm_models("spacy.groq.v1")
def groq(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
name: str = "llama-3.1-70b-versatile",
strict: bool = Groq.DEFAULT_STRICT,
max_tries: int = Groq.DEFAULT_MAX_TRIES,
interval: float = Groq.DEFAULT_INTERVAL,
max_request_time: float = Groq.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
) -> Groq:
"""Returns Groq instance for 'llama-3.1-70b-versatile' model using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    name (str): Model name to use. Can be any model name supported by the Groq API, e.g. 'llama-3.1-70b-versatile'
        or 'mixtral-8x7b-32768'.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (Groq): Groq instance for 'llama-3.1-70b-versatile' model.

DOCS: https://spacy.io/api/large-language-models#models
"""
return Groq(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)
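
For reference, a minimal sketch of how this registry entry could be used from Python, following the pattern of the existing spacy-llm usage examples (the task and labels are illustrative, and `GROQ_API_KEY` must be set as described in the usage example below):

```python
import spacy

nlp = spacy.blank("en")
# Register an llm pipe backed by the new spacy.groq.v1 model.
nlp.add_pipe(
    "llm",
    config={
        "task": {
            "@llm_tasks": "spacy.NER.v3",
            "labels": ["DISH", "INGREDIENT", "EQUIPMENT"],
        },
        "model": {
            "@llm_models": "spacy.groq.v1",
            "name": "llama-3.1-70b-versatile",
        },
    },
)
doc = nlp("Sriracha sauce goes really well with hoisin stir fry.")
print([(ent.text, ent.label_) for ent in doc.ents])
```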
75 changes: 75 additions & 0 deletions usage_examples/ner_groq/README.md
@@ -0,0 +1,75 @@
# Using Groq Models for Named Entity Recognition (NER)

This example shows how you can use a model hosted by Groq for Named Entity Recognition (NER).
The NER prompt is based on the [PromptNER](https://arxiv.org/abs/2305.15444) paper and
utilizes Chain-of-Thought reasoning to extract named entities.

First, create a new API key from
[console.groq.com](https://console.groq.com/keys) or fetch an existing
one. Record the secret key and make sure it is available as an environment
variable:

```sh
export GROQ_API_KEY="gsk-..."
```

Then, you can run the pipeline on a sample text via:


```sh
python run_pipeline.py [TEXT] [PATH TO CONFIG] [PATH TO FILE WITH EXAMPLES]
```

For example:

```sh
python run_pipeline.py \
""Sriracha sauce goes really well with hoisin stir fry, but you should add it after you use the wok." \
./fewshot.cfg
./examples.json
```

This example assigns the labels DISH, INGREDIENT, and EQUIPMENT.

You can adapt the labels and examples to your use case. The few-shot examples
live in the `examples.json` file; feel free to change and update them to your
liking. Other file formats, including `yml` and `jsonl`, are also supported for
these examples. A sketch of how the config wires these examples in follows below.
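
For orientation, a task block along these lines wires the examples file into the NER task. This is a sketch only — the `fewshot.cfg` shipped in this directory is authoritative, and the model block assumes the new `spacy.groq.v1` registry entry from this PR:

```ini
[components.llm.task]
@llm_tasks = "spacy.NER.v3"
labels = ["DISH", "INGREDIENT", "EQUIPMENT"]

[components.llm.task.examples]
@misc = "spacy.FewShotReader.v1"
path = "examples.json"

[components.llm.model]
@llm_models = "spacy.groq.v1"
name = "llama-3.1-70b-versatile"
```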


### Negative examples

While not required, the Chain-of-Thought reasoning for the `spacy.NER.v3` task
works best in our experience when both positive and negative examples are provided.

This prompts the language model with concrete examples of what **is not** an entity
for your use case.

Here's an example that helps define the INGREDIENT label for the LLM.

```json
[
{
"text": "You can't get a great chocolate flavor with carob.",
"spans": [
{
"text": "chocolate",
"is_entity": false,
"label": "==NONE==",
"reason": "is a flavor in this context, not an ingredient"
},
{
"text": "carob",
"is_entity": true,
"label": "INGREDIENT",
"reason": "is an ingredient to add chocolate flavor"
}
]
}
...
]
```

In this example, "chocolate" is not an ingredient even though it could be in other contexts.
We explain that via the "reason" property of this example.
3 changes: 3 additions & 0 deletions usage_examples/ner_groq/__init__.py
@@ -0,0 +1,3 @@
from .run_pipeline import run_pipeline

__all__ = ["run_pipeline"]
36 changes: 36 additions & 0 deletions usage_examples/ner_groq/examples.json
@@ -0,0 +1,36 @@
[
{
"text": "You can't get a great chocolate flavor with carob.",
"spans": [
{
"text": "chocolate",
"is_entity": false,
"label": "==NONE==",
"reason": "is a flavor in this context, not an ingredient"
},
{
"text": "carob",
"is_entity": true,
"label": "INGREDIENT",
"reason": "is an ingredient to add chocolate flavor"
}
]
},
{
"text": "You can probably sand-blast it if it's an anodized aluminum pan",
"spans": [
{
"text": "sand-blast",
"is_entity": false,
"label": "==NONE==",
"reason": "is a cleaning technique, not some kind of equipment"
},
{
"text": "anodized aluminum pan",
"is_entity": true,
"label": "EQUIPMENT",
"reason": "is a piece of cooking equipment, anodized is included since it describes the type of pan"
}
]
}
]