Commit: add minor CLI updates
santteegt committed Nov 7, 2024
1 parent cabd5d3 commit 8e4f8e4
Showing 4 changed files with 40 additions and 38 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -8,4 +8,6 @@ tsconfig.json
 vocs.config.ts
 
 docs/
+models/
+output/
 tests/
1 change: 1 addition & 0 deletions .gitignore
@@ -203,3 +203,4 @@ yarn-error.log*
 
 # Ollama
 Modelfile
+*.gguf
51 changes: 34 additions & 17 deletions cli.py
@@ -8,6 +8,7 @@
 import requests
 import shutil
 import subprocess
+import typing
 
 
 class LazyGroup(click.Group):
@@ -64,7 +65,7 @@ def _lazy_load(self, cmd_name):
         return cmd_object
 
 
-def ping_service(url: str, service_name: str, debug: bool = False) -> bool:
+def ping_service(url: str, service_name: str, headers: typing.Dict[str, any] = dict(), debug: bool = False) -> bool:
     """
     Ping a service to check its availability.
@@ -83,14 +84,16 @@ def ping_service(url: str, service_name: str, debug: bool = False) -> bool:
         Exception: If the service response is not successful (i.e., non-OK HTTP status).
     """
     try:
-        ping = requests.get(url)
+        ping = requests.get(url, headers=headers)
         if not ping.ok:
             raise Exception(f"ERROR: {service_name} (@ {url}). Reason: {ping.reason}")
         if debug:
             click.echo(click.style(ping.headers, fg="blue"))
+            click.echo(click.style(ping.json(), fg="blue"))
         click.echo(click.style(f"{service_name} connection OK!", fg="green"))
     except Exception as e:
         click.echo(click.style(f"ERROR: {service_name} (@ {url}) is down. {e}", fg="red"), err=True)
+        click.echo(click.style(ping.headers, fg="blue")) if debug else None
         click.echo("Try again...")
         return False
     return True
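
For context, a self-contained sketch of what the updated ping_service now does: probe the provider's /models endpoint with an Authorization header and report success or failure. The URL and API key below are illustrative assumptions, not values from this commit.

    import requests

    url = "http://127.0.0.1:8080/v1/models"  # assumed local OpenAI-compatible server
    headers = {"Authorization": "Bearer empty-api-key"}  # placeholder key, as in setup()
    try:
        ping = requests.get(url, headers=headers)
        if not ping.ok:
            raise Exception(f"ERROR: LLM Provider API (@ {url}). Reason: {ping.reason}")
        print("LLM Provider API connection OK!")
    except Exception as e:
        print(f"ERROR: LLM Provider API (@ {url}) is down. {e}")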
@@ -217,8 +220,7 @@ def setup(
         env_file.unlink()
 
     if init_setup:
-        shutil.copyfile(f"{ENV_FILE_PATH}.sample", ENV_FILE_PATH)
-        dotenv.load_dotenv(ENV_FILE_PATH)
+        env_vars = dict()
 
     # ------------------------------------------------------------------------------------------------------
 
@@ -233,39 +235,45 @@ def setup(
         llm_provider = "openai" # gaia uses an openai-like API server
 
     click.echo(f"LLM_PROVIDER={llm_provider}") if debug else None
-    dotenv.set_key(ENV_FILE_PATH, key_to_set="LLM_PROVIDER", value_to_set=llm_provider)
+    env_vars["LLM_PROVIDER"] = llm_provider
 
     # Set LLM_API_BASE_URL
     llm_api_base_url_default = "http://127.0.0.1:11434" if llm_provider == "ollama" else "http://127.0.0.1:8080/v1"
+    llm_api_server_url = llm_api_base_url_default
+    llm_api_key = 'empty-api-key'
+    llm_api_key_set = False
     while True:
         llm_api_server_url = click.prompt(
             "LLM provider API URL",
             type=click.STRING,
             default=llm_api_base_url_default,
             show_default=True
         )
 
+        # Set LLM_API_KEY
+        if not llm_api_key_set and llm_provider_chosen == "other" and llm_provider == "openai":
+            llm_api_key = click.prompt("LLM provider API Key", type=click.STRING, hide_input=True)
+            llm_api_key_set = True
+            click.echo(f"LLM_API_KEY value updated") if debug else None
+            env_vars["LLM_API_KEY"] = llm_api_key
 
         service_url = llm_api_server_url if re.search("(v1)|(v1/)$", llm_api_server_url) else f"{llm_api_server_url}/v1"
         service_url += "/models"
-        if ping_service(service_url, "LLM Provider API", debug=debug):
+        headers = {
+            "Authorization": f"Bearer {llm_api_key}"
+        }
+        if ping_service(service_url, "LLM Provider API", headers=headers, debug=debug):
             break
     click.echo(f"LLM_API_BASE_URL={llm_api_server_url}") if debug else None
-    dotenv.set_key(ENV_FILE_PATH, key_to_set="LLM_API_BASE_URL", value_to_set=llm_api_server_url)
-
-    # Set LLM_API_KEY
-    if llm_provider_chosen == "other" and llm_provider == "openai":
-        llm_api_key = click.prompt("LLM provider API Key", type=click.STRING, hide_input=True)
-        click.echo(f"LLM_API_KEY value updated") if debug else None
-        dotenv.set_key(ENV_FILE_PATH, key_to_set="LLM_API_KEY", value_to_set=llm_api_key)
+    env_vars["LLM_API_BASE_URL"] = llm_api_server_url
 
     # Set LLM_EMBEDDINGS_*
     llm_embeddings_model = click.prompt("Embeddings model Name", type=click.STRING, default="Nomic-embed-text-v1.5", show_default=True)
     click.echo(f"LLM_EMBEDDINGS_MODEL={llm_embeddings_model}") if debug else None
-    dotenv.set_key(ENV_FILE_PATH, key_to_set="LLM_EMBEDDINGS_MODEL", value_to_set=llm_embeddings_model)
+    env_vars["LLM_EMBEDDINGS_MODEL"] = llm_embeddings_model
     llm_embeddings_vector_size = click.prompt("Embeddings Vector Size", type=click.IntRange(min=0, min_open=True), default=768, show_default=True)
     click.echo(f"LLM_EMBEDDINGS_VECTOR_SIZE={llm_embeddings_vector_size}") if debug else None
-    dotenv.set_key(ENV_FILE_PATH, key_to_set="LLM_EMBEDDINGS_VECTOR_SIZE", value_to_set=str(llm_embeddings_vector_size), quote_mode="never")
+    env_vars["LLM_EMBEDDINGS_VECTOR_SIZE"] = int(llm_embeddings_vector_size)
     if llm_provider == "ollama":
         # check Ollama is installed
         out = subprocess.run(["which", "ollama"], capture_output=True)
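
The probe URL above is derived by appending /v1 only when the base URL does not already contain it, then targeting /models. A minimal sketch reusing the same expression from the diff; the sample base URLs are assumptions:

    import re

    for base in ("http://127.0.0.1:11434", "http://127.0.0.1:8080/v1"):
        service_url = base if re.search("(v1)|(v1/)$", base) else f"{base}/v1"
        print(service_url + "/models")  # e.g. http://127.0.0.1:11434/v1/models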
@@ -284,7 +292,7 @@ def setup(
         click.echo(click.style(model_info, fg="blue"), err=True) if debug else None
         if model_info.stderr.find(b"not found") > 0:
             # request model file and load model
-            embeddings_model_file = click.prompt("Enter the Path to the Embeddings model file", type=click.Path(exists=True, dir_okay=False))
+            embeddings_model_file = click.prompt("Enter the Absolute Path to the Embeddings model file", type=click.Path(exists=True, dir_okay=False))
             # create Modelfile
             with open("./models/Modelfile", "w") as f:
                 f.write(f"FROM {embeddings_model_file}")
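
The generated Modelfile only points Ollama at the local GGUF weights; registering the model is a separate step. A hedged sketch of that follow-up, where the model name is an illustrative assumption, not taken from this commit:

    import shutil
    import subprocess

    # Register the embeddings model from the Modelfile written above.
    if shutil.which("ollama"):
        subprocess.run(
            ["ollama", "create", "nomic-embed-text-v1.5", "-f", "./models/Modelfile"],
            check=True,
        )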
@@ -351,12 +359,21 @@ def setup(
         click.echo(click.style(f"A QdrantDB Doker container is already running", fg="yellow"))
 
     click.echo(f"QDRANTDB_URL={qdrantdb_url}") if debug else None
-    dotenv.set_key(ENV_FILE_PATH, key_to_set="QDRANTDB_URL", value_to_set=qdrantdb_url)
+    env_vars["QDRANTDB_URL"] = qdrantdb_url
 
     step += 1
 
     # ------------------------------------------------------------------------------------------------------
 
+    # Set .env file
+    click.echo(click.style(f"Saving Pipeline settings in {ENV_FILE_PATH}...", fg="yellow"))
+    shutil.copyfile(f"{ENV_FILE_PATH}.sample", ENV_FILE_PATH)
+
+    for key, val in env_vars.items():
+        dotenv.set_key(ENV_FILE_PATH, key_to_set=key, value_to_set=val, quote_mode="always" if type(val) == str else "never")
+
+    # ------------------------------------------------------------------------------------------------------
+
     click.echo(click.style(f"Pipeline Setup completed!", fg="green"))
     click.echo(
         f"""
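
The net effect of these setup() changes is a deferred-write pattern: settings accumulate in env_vars during the interactive steps and are flushed to the .env file once, at the end. A runnable sketch of the pattern; file names and sample values are assumptions:

    from pathlib import Path
    import shutil
    import dotenv  # python-dotenv

    ENV_FILE_PATH = ".env"
    # Stand-in for the repo's .env.sample (illustrative content only).
    Path(f"{ENV_FILE_PATH}.sample").write_text("LLM_PROVIDER=ollama\n")

    # Values gathered during the interactive prompts...
    env_vars = {"LLM_PROVIDER": "openai", "LLM_EMBEDDINGS_VECTOR_SIZE": 768}

    # ...are written out in one pass: strings quoted, numbers left bare.
    shutil.copyfile(f"{ENV_FILE_PATH}.sample", ENV_FILE_PATH)
    for key, val in env_vars.items():
        dotenv.set_key(ENV_FILE_PATH, key_to_set=key, value_to_set=str(val),
                       quote_mode="always" if isinstance(val, str) else "never")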
24 changes: 3 additions & 21 deletions run.py
@@ -65,13 +65,13 @@ def check_services(settings: Settings) -> bool:
     service_url += "/models"
     llm_provider_up = ping_service(service_url, settings.llm_provider)
     if not llm_provider_up:
-        click.echo(click.style(f"ERROR: LLM Provider {settings.llm_provider} ({service_url}) is down", fg="red"), err=True)
+        click.echo(click.style(f"ERROR: LLM Provider {settings.llm_provider} ({service_url}) is down.", fg="red"), err=True)
 
     # QdrantDB
     service_url = docker_replace_local_service_url(settings.qdrantdb_url, "qdrant")
     qdrantdb_up = ping_service(service_url, "QdrantDB")
     if not qdrantdb_up:
-        click.echo(click.style(f"ERROR: QdrantDB ({service_url}) is down", fg="red"), err=True)
+        click.echo(click.style(f"ERROR: QdrantDB ({service_url}) is down.", fg="red"), err=True)
 
     # Check all services status
     services_ok = llm_provider_up and qdrantdb_up
@@ -101,7 +101,6 @@ def cli(ctx, debug):
 @click.pass_context
 @click.argument("api-manifest-file", type=click.Path(exists=True))
 @click.argument("openapi-spec-file", type=click.Path(exists=True))
-@click.option("--llm-provider", type=click.Choice(["ollama", "openai"], case_sensitive=True), help="Embedding model provider")
 @click.option("--api-key", default=lambda: os.environ.get("API_KEY", ""), help="API Auth key", type=click.STRING, prompt=True, prompt_required=False)
 @click.option("--source-manifest-file", default="", help="Source YAML manifest", type=click.Path()) # TODO: fix validation when empty
 @click.option("--full-refresh", is_flag=True, help="Clean up cache and extract API data from scratch")
@@ -112,7 +111,6 @@ def all(
     ctx,
     api_manifest_file: str,
     openapi_spec_file: str,
-    llm_provider: str,
     api_key: str,
     source_manifest_file: str,
     full_refresh: bool,
@@ -132,7 +130,6 @@ def all(
     Args:
         ctx (click.Context): The context object for the CLI.
         api_manifest_file (str): Path to the API manifest YAML file that defines pipeline config settings and API endpoints.
-        llm_provider (str): Provider of the embedding model, e.g., "ollama" or "openai".
         api_key (str): API authentication key.
         openapi_spec_file (str): Path to the OpenAPI YAML specification file.
         source_manifest_file (str): Path to the source YAML manifest file. If empty, the API manifest file is used to load data.
@@ -157,12 +154,9 @@ def all(
         openapi_spec_file=openapi_spec_file, # NOTICE: CLI param
         source_manifest_file=source_manifest_file # NOTICE: CLI param
     )
-    if llm_provider:
-        args['llm_provider'] = llm_provider # NOTICE: CLI param over env var
     if api_key:
         args['api_key'] = api_key # NOTICE: CLI param over env var
-
-    # TODO: set env_file based on dev/prod
+
     settings = get_settings(**args)
     logger.info(f"Config settings - {settings.model_dump()}")
     logger.debug(f"Full refresh? - {full_refresh}")
@@ -255,13 +249,11 @@ def all(
 @cli.command()
 @click.pass_context
 @click.argument("api-manifest-file", type=click.Path(exists=True))
-@click.option("--llm-provider", type=click.Choice(["ollama", "openai"], case_sensitive=True), help="Embedding model provider")
 @click.option("--normalized-data-file", required=True, help="Normalized data in JSONL format", type=click.Path(exists=True))
 @Timer(name="rag-api-pipeline", text="from-normalized pipeline executed after: {:.2f} seconds", logger=logger.info)
 def from_normalized(
     ctx,
     api_manifest_file: str,
-    llm_provider: str,
     normalized_data_file: str
 ):
     """
@@ -275,17 +267,13 @@ def from_normalized(
     Args:
         ctx (click.Context): The context object for the CLI.
         api_manifest_file (str): Path to the API manifest YAML file that defines pipeline config settings and API endpoints.
-        llm_provider (str): Provider of the embedding model, e.g., "ollama" or "openai".
         normalized_data_file (str): Path to the JSONL file containing normalized data.
     """
 
     if not is_pipeline_setup():
         return
 
     args = dict()
-    if llm_provider:
-        args['llm_provider'] = llm_provider # NOTICE: CLI param over env var
-
     settings = get_settings(*args)
 
     logger.info(f"Config settings - {settings.model_dump()}")
@@ -350,13 +338,11 @@ def from_normalized(
 @cli.command()
 @click.pass_context
 @click.argument("api-manifest-file", type=click.Path(exists=True))
-@click.option("--llm-provider", type=click.Choice(["ollama", "openai"], case_sensitive=True), help="Embedding model provider")
 @click.option("--chunked-data-file", required=True, help="Chunked data in JSONL format", type=click.Path(exists=True))
 @Timer(name="rag-api-pipeline", text="from-chunked pipeline executed after: {:.2f} seconds", logger=logger.info)
 def from_chunked(
     ctx,
     api_manifest_file: str,
-    llm_provider: str,
     chunked_data_file: str
 ):
     """
@@ -370,17 +356,13 @@ def from_chunked(
     Args:
         ctx (click.Context): The context object for the CLI.
         api_manifest_file (str): Path to the API manifest YAML file that defines pipeline config settings and API endpoints.
-        llm_provider (str): Provider of the embedding model, e.g., "ollama" or "openai".
         chunked_data_file (str): Path to the JSONL file containing chunked data.
     """
 
     if not is_pipeline_setup():
         return
 
     args = dict()
-    if llm_provider:
-        args['llm_provider'] = llm_provider # NOTICE: CLI param over env var
-
     settings = get_settings(**args)
 
     logger.info(f"Config settings - {settings.model_dump()}")
