Skip to content

Commit

Permalink
Config refactor (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Mar 4, 2024
1 parent 91e603d commit be71768
Show file tree
Hide file tree
Showing 17 changed files with 339 additions and 164 deletions.
4 changes: 2 additions & 2 deletions docs/documentation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _prepare_config(self, config: Config) -> tuple[str, str]:
)
config_path = str(deploy_directory / "ragna.toml")

config.local_cache_root = deploy_directory
config.local_root = deploy_directory

sys.modules["__main__"].__file__ = inspect.getouterframes(
inspect.currentframe()
Expand All @@ -51,7 +51,7 @@ def _prepare_config(self, config: Config) -> tuple[str, str]:
# to source storages.
file.write("from ragna import assistants\n\n")

for assistant in config.components.assistants:
for assistant in config.assistants:
if assistant.__module__ == "__main__":
file.write(f"{inspect.getsource(assistant)}\n\n")
assistant.__module__ = custom_module
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/gallery_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def answer(self, prompt, sources):

from ragna.deploy import Config

config = Config(components={"assistants": [DemoStreamingAssistant]})
config = Config(assistants=[DemoStreamingAssistant])

rest_api = documentation_helpers.RestApi()

Expand Down
44 changes: 27 additions & 17 deletions docs/references/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,46 +47,52 @@ All configuration options can be set or overritten by environment variables by u
file is equivalent to setting `RAGNA_DOCUMENT=ragna.core.LocalDocument`.

For configuration options in subsections, the subsection name needs to be appended to
the prefix, e.g. `RAGNA_COMPONENTS_`. The value needs to be in JSON format. For example
the prefix, e.g. `RAGNA_API_`. The value needs to be in JSON format. For example

```toml
[components]
assistants = [
"ragna.assistants.RagnaDemoAssistant",
[api]
origins = [
"http://localhost:31477",
]
```

is equivalent to
`RAGNA_COMPONENTS_ASSISTANTS='["ragna.assistants.RagnaDemoAssistant"]'`.
is equivalent to `RAGNA_API_ORIGINS='["http://localhost:31477"]'`.

## Configuration options

### `local_cache_root`
### `local_root`

### `document`

[ragna.core.Document][] class to use to upload and read documents.
Local root directory Ragna uses for storing files. See [ragna.local_root][].

### `authentication`

[ragna.deploy.Authentication][] class to use for authenticating users.

### `components`
### `document`

#### `source_storages`
[ragna.core.Document][] class to use to upload and read documents.

### `source_storages`

[ragna.core.SourceStorage][]s to be available for the user to use.

#### `assistants`
### `assistants`

[ragna.core.Assistant][]s to be available for the user to use.

### `api`

#### `hostname`

Hostname the REST API will be bound to.

#### `port`

Port the REST API will be bound to.

#### `url`

1. Hostname and port the REST API server will be bound to.
2. URL of the REST API to be accessed by the web UI.
URL of the REST API to be accessed by the web UI.

#### `origins`

Expand All @@ -106,9 +112,13 @@ external clients.

### `ui`

#### `url`
#### `hostname`

Hostname the web UI will be bound to.

#### `port`

Hostname and port the web UI server will be bound to.
Port the web UI will be bound to.

#### `origins`

Expand Down
2 changes: 2 additions & 0 deletions docs/references/python-api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Python API reference

::: ragna.local_root

::: ragna.core

::: ragna.source_storages
Expand Down
14 changes: 8 additions & 6 deletions ragna-docker.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
local_cache_root = "/var/ragna"
document = "ragna.core.LocalDocument"
local_root = "/var/ragna"
authentication = "ragna.deploy.RagnaDemoAuthentication"

[components]
document = "ragna.core.LocalDocument"
source_storages = [
"ragna.source_storages.Chroma",
"ragna.source_storages.RagnaDemoSourceStorage",
Expand All @@ -13,10 +11,14 @@ assistants = [
]

[api]
url = "http://0.0.0.0:31476"
hostname = "0.0.0.0"
port = 31476
url = "http://localhost:31476"
origins = ["http://localhost:31477"]
database_url = "sqlite:////var/ragna/ragna.db"
root_path = ""

[ui]
url = "http://0.0.0.0:31477"
hostname = "0.0.0.0"
port = 31477
origins = ["http://localhost:31477"]
10 changes: 10 additions & 0 deletions ragna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,13 @@

from . import assistants, core, deploy, source_storages
from .core import Rag

__all__ = [
"__version__",
"Rag",
"assistants",
"core",
"deploy",
"local_root",
"source_storages",
]
19 changes: 18 additions & 1 deletion ragna/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,27 @@
)


def make_directory(path: Union[str, Path]) -> Path:
path = Path(path).expanduser().resolve()
path.mkdir(parents=True, exist_ok=True)
return path


def local_root(path: Optional[Union[str, Path]] = None) -> Path:
"""Get or set the local root directory Ragna uses for storing files.
Defaults to the value of the `RAGNA_LOCAL_ROOT` environment variable or otherwise to
`~/.cache/ragna`.
Args:
path: If passed, this is set as new local root directory.
Returns:
Ragnas local root directory.
"""
global _LOCAL_ROOT
if path is not None:
_LOCAL_ROOT = Path(path).expanduser().resolve()
_LOCAL_ROOT = make_directory(path)

return _LOCAL_ROOT

Expand Down
2 changes: 1 addition & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def get_upload_info(
algorithm=cls._JWT_ALGORITHM,
)
}
metadata = {"path": str(config.local_cache_root / "documents" / str(id))}
metadata = {"path": str(config.local_root / "documents" / str(id))}
return metadata, DocumentUploadParameters(method="PUT", url=url, data=data)

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@


class Rag(Generic[C]):
"""RAG workflow."""
"""RAG workflow.
!!! tip
This class can be imported from `ragna` directly, e.g.
```python
from ragna import Rag
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
Expand Down
11 changes: 4 additions & 7 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@


def app(config: Config) -> FastAPI:
ragna.local_root(config.local_cache_root)
ragna.local_root(config.local_root)

rag = Rag() # type: ignore[var-annotated]
components_map: dict[str, Component] = {
component.display_name(): rag._load_component(component)
for component in itertools.chain(
config.components.source_storages, config.components.assistants
)
for component in itertools.chain(config.source_storages, config.assistants)
}

def get_component(display_name: str) -> Component:
Expand Down Expand Up @@ -100,11 +98,10 @@ async def get_components(_: UserDependency) -> schemas.Components:
documents=sorted(config.document.supported_suffixes()),
source_storages=[
_get_component_json_schema(source_storage)
for source_storage in config.components.source_storages
for source_storage in config.source_storages
],
assistants=[
_get_component_json_schema(assistant)
for assistant in config.components.assistants
_get_component_json_schema(assistant) for assistant in config.assistants
],
)

Expand Down
76 changes: 54 additions & 22 deletions ragna/deploy/_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,19 @@ def _wizard_builtin() -> Config:
"If the requirements of a selected component are not met, "
"I'll show you instructions how to meet them later."
)
config.components.source_storages = _select_components(
config.source_storages = _select_components(
"source storages",
ragna.source_storages,
SourceStorage, # type: ignore[type-abstract]
)
config.components.assistants = _select_components(
config.assistants = _select_components(
"assistants",
ragna.assistants,
Assistant, # type: ignore[type-abstract]
)

_handle_unmet_requirements(
itertools.chain(config.components.source_storages, config.components.assistants)
itertools.chain(config.source_storages, config.assistants)
)

return config
Expand All @@ -159,11 +159,16 @@ def _select_components(
module: ModuleType,
base_cls: Type[T],
) -> list[Type[T]]:
components = [
obj
for obj in module.__dict__.values()
if isinstance(obj, type) and issubclass(obj, base_cls) and obj is not base_cls
]
components = sorted(
(
obj
for obj in module.__dict__.values()
if isinstance(obj, type)
and issubclass(obj, base_cls)
and obj is not base_cls
),
key=lambda component: component.display_name(),
)
return cast(
list[Type[T]],
questionary.checkbox(
Expand Down Expand Up @@ -238,32 +243,59 @@ def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None:
def _wizard_common() -> Config:
config = _wizard_builtin()

config.local_cache_root = Path(
config.local_root = Path(
questionary.path(
"Where should local files be stored?",
default=str(config.local_cache_root),
default=str(config.local_root),
qmark=QMARK,
).unsafe_ask()
)

config.api.url = questionary.text(
"At what URL do you want the ragna REST API to be served?",
default=config.api.url,
qmark=QMARK,
).unsafe_ask()
for sub_config, title in [(config.api, "REST API"), (config.ui, "web UI")]:
sub_config.hostname = questionary.text(
f"What hostname do you want to bind the the Ragna {title} to?",
default=sub_config.hostname, # type: ignore[attr-defined]
qmark=QMARK,
).unsafe_ask()

sub_config.port = int(
questionary.text(
f"What port do you want to bind the the Ragna {title} to?",
default=str(sub_config.port), # type: ignore[attr-defined]
qmark=QMARK,
).unsafe_ask()
)

config.api.database_url = questionary.text(
"What is the URL of the database?",
default=f"sqlite:///{config.local_cache_root / 'ragna.db'}",
"What is the URL of the SQL database?",
default=Config(local_root=config.local_root).api.database_url,
qmark=QMARK,
).unsafe_ask()

config.ui.url = questionary.text(
"At what URL do you want the ragna web UI to be served?",
default=config.ui.url,
config.api.url = questionary.text(
"At which URL will the Ragna REST API be served?",
default=Config(
api=dict( # type: ignore[arg-type]
hostname=config.api.hostname,
port=config.api.port,
)
).api.url,
qmark=QMARK,
).unsafe_ask()

config.api.origins = config.ui.origins = [
questionary.text(
"At which URL will the Ragna web UI be served?",
default=Config(
ui=dict( # type: ignore[arg-type]
hostname=config.ui.hostname,
port=config.ui.port,
)
).api.origins[0],
qmark=QMARK,
).unsafe_ask()
]

return config


Expand Down Expand Up @@ -310,8 +342,8 @@ def check_config(config: Config) -> bool:
fully_available = True

for title, components in [
("source storages", config.components.source_storages),
("assistants", config.components.assistants),
("source storages", config.source_storages),
("assistants", config.assistants),
]:
components = cast(list[Type[Component]], components)

Expand Down
Loading

0 comments on commit be71768

Please sign in to comment.