-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jack Luar <[email protected]>
- Loading branch information
Showing
8 changed files
with
135 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from huggingface_hub import snapshot_download | ||
import os | ||
|
||
if __name__ == "__main__":
    # Download next to this script so the result does not depend on the
    # caller's working directory.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    snapshot_download(
        "The-OpenROAD-Project/ORAssistant_Public_Evals",
        # NOTE(review): snapshot_download looks for a *model* repo unless
        # repo_type is given; this evals repo is presumably published as a
        # Hub dataset -- confirm against the Hub page.
        repo_type="dataset",
        revision="main",
        local_dir=cur_dir,
        # Skip repository metadata files; only the remaining files are pulled.
        ignore_patterns=[
            ".gitattributes",
            "README.md",
        ],
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os | ||
|
||
from dotenv import load_dotenv | ||
from src.vertex_ai import GoogleVertexAILangChain | ||
from src.metrics.accuracy import make_correctness_metric | ||
from deepeval.test_case import LLMTestCase | ||
|
||
# Load environment variables from the .env file located two directory levels
# above this script.
cur_dir = os.path.dirname(__file__)
root_dir = os.path.join(cur_dir, "../../")
dotenv_path = os.path.join(root_dir, ".env")
load_dotenv(dotenv_path)


if __name__ == "__main__":
    # Smoke test: score one hand-written example with the correctness metric
    # and print the resulting score and the judge's reasoning.
    judge = GoogleVertexAILangChain(model_name="gemini-1.5-pro-002")
    correctness = make_correctness_metric(judge)

    sample = LLMTestCase(
        input="The dog chased the cat up the tree, who ran up the tree?",
        actual_output="It depends, some might consider the cat, while others might argue the dog.",
        expected_output="The cat.",
    )
    correctness.measure(sample)

    print(correctness.score)
    print(correctness.reason)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
""" | ||
Accuracy related metrics from DeepEval | ||
""" | ||
|
||
from deepeval.metrics import GEval | ||
from deepeval.test_case import LLMTestCaseParams | ||
from deepeval.models.base_model import DeepEvalBaseLLM | ||
|
||
|
||
def make_correctness_metric(model: DeepEvalBaseLLM) -> GEval:
    """Build a G-Eval "Correctness" metric that is judged by ``model``.

    The criteria and evaluation steps compare the actual output against the
    expected output, so the expected output is passed to the judge alongside
    the input and the actual output. Test cases measured with this metric
    therefore need ``expected_output`` to be set.

    Args:
        model: The DeepEval LLM wrapper used as the judge.

    Returns:
        A configured ``GEval`` metric named "Correctness".
    """
    return GEval(
        name="Correctness",
        criteria="Determine whether the actual output is factually correct based on the expected output.",
        evaluation_steps=[
            "Check whether the facts in 'actual output' contradicts any facts in 'expected output'",
            "You should also heavily penalize omission of detail",
            "Vague language, or contradicting OPINIONS, are OK",
        ],
        evaluation_params=[
            LLMTestCaseParams.INPUT,
            LLMTestCaseParams.ACTUAL_OUTPUT,
            # Without this the judge never sees the reference answer that the
            # criteria and steps above tell it to compare against.
            LLMTestCaseParams.EXPECTED_OUTPUT,
        ],
        model=model,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" | ||
Code is adapted from https://github.com/meteatamel/genai-beyond-basics/blob/main/samples/evaluation/deepeval/vertex_ai/google_vertex_ai_langchain.py | ||
Custom DeepEvalLLM wrapper. | ||
""" | ||
|
||
from typing import Any | ||
|
||
from langchain_google_vertexai import ChatVertexAI, HarmBlockThreshold, HarmCategory | ||
from deepeval.models.base_model import DeepEvalBaseLLM | ||
|
||
|
||
class GoogleVertexAILangChain(DeepEvalBaseLLM):
    """DeepEval LLM adapter backed by LangChain's ``ChatVertexAI`` client."""

    def __init__(self, model_name, *args, **kwargs):
        # Everything is forwarded unchanged to the DeepEvalBaseLLM constructor.
        super().__init__(model_name, *args, **kwargs)

    def load_model(self, *args, **kwargs):
        """Create the ``ChatVertexAI`` client with all safety filters disabled.

        Positional and keyword arguments are accepted but not used here.
        """
        # Every harm category is set to BLOCK_NONE so that evaluation
        # responses are not blocked by Vertex AI's safety filters.
        unblocked_categories = (
            HarmCategory.HARM_CATEGORY_UNSPECIFIED,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            HarmCategory.HARM_CATEGORY_HARASSMENT,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        )
        filters = {
            category: HarmBlockThreshold.BLOCK_NONE
            for category in unblocked_categories
        }
        return ChatVertexAI(model_name=self.model_name, safety_settings=filters)

    def generate(self, prompt: str) -> Any:
        """Invoke the model synchronously and return the reply's ``content``."""
        # NOTE(review): relies on the base class exposing the loaded client
        # as ``self.model`` -- confirm against DeepEvalBaseLLM.
        reply = self.model.invoke(prompt)
        return reply.content

    async def a_generate(self, prompt: str) -> Any:
        """Invoke the model asynchronously and return the reply's ``content``."""
        return (await self.model.ainvoke(prompt)).content

    def get_model_name(self):
        """Return ``self.model_name``."""
        return self.model_name
|
||
|
||
def main():
    """Synchronous smoke test: send one fixed prompt and print the reply."""
    llm = GoogleVertexAILangChain(model_name="gemini-1.5-pro-002")

    prompt = "Write me a joke"
    print(f"Prompt: {prompt}")

    response = llm.generate(prompt)
    print(f"Response: {response}")
|
||
|
||
async def main_async():
    """Asynchronous smoke test: send one fixed prompt and print the reply."""
    llm = GoogleVertexAILangChain(model_name="gemini-1.5-pro-002")

    prompt = "Write me a joke"
    print(f"Prompt: {prompt}")

    response = await llm.a_generate(prompt)
    print(f"Response: {response}")
|
||
|
||
if __name__ == "__main__":
    # Runs the synchronous demo only. The async variant can be exercised with
    # asyncio.run(main_async()); note that asyncio is not among the imports
    # visible in this file and would have to be imported first.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters