Commit
implement stream preprocessors
santteegt committed Aug 31, 2024
1 parent 2ba77d0 commit 13afb0a
Showing 11 changed files with 305 additions and 120 deletions.
42 changes: 21 additions & 21 deletions config/mappings.yaml
@@ -164,37 +164,37 @@ textSchemas:
Protocol:
type: object
properties:
cname:
type: string
name:
type: string
categories:
type: array
items:
- cname:
type: string
- name:
type: string
- categories:
type: array
items:
type: string
Proposal:
type: object
properties:
title:
type: string
content:
type: string
summary:
type: string
- title:
type: string
- content:
type: string
- summary:
type: string
DiscourseTopic:
type: object
properties:
title:
type: string
- title:
type: string
DiscourseCategory:
type: object
properties:
name:
type: string
description:
type: string
- name:
type: string
- description:
type: string
DiscourseTopicPost:
type: object
properties:
body:
type: string
- body:
type: string
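The mappings.yaml change above converts each textSchema's `properties` from a mapping into a list of single-key entries, which keeps the field order explicit and matches the per-field iteration added in loader.py below. As a minimal sketch (illustrative only, assuming the mappings are loaded with PyYAML; the helper name and the sample schema are not part of this commit), such a list can be flattened into `(field, props)` pairs like this:

```python
# Illustrative only: flatten the new list-style textSchema properties into
# (field, props) pairs, mirroring the loop added in loader.py.
import yaml

SAMPLE = """
Proposal:
  type: object
  properties:
    - title:
        type: string
    - summary:
        type: string
"""

def text_schema_fields(text_schema: dict) -> list:
    """Each list entry maps exactly one field name to its spec."""
    pairs = []
    for entry in text_schema.get("properties", []):
        field = list(entry.keys())[0]
        pairs.append((field, entry[field]))
    return pairs

schemas = yaml.safe_load(SAMPLE)
print(text_schema_fields(schemas["Proposal"]))
# [('title', {'type': 'string'}), ('summary', {'type': 'string'})]
```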
91 changes: 61 additions & 30 deletions gaianet_rag_api_pipeline/input.py
@@ -1,57 +1,88 @@
from gaianet_rag_api_pipeline.config import get_settings, Settings
from gaianet_rag_api_pipeline.io.airbyte import AirbyteAPIConnector
from gaianet_rag_api_pipeline.schema import BoardroomAPI # TODO:
from gaianet_rag_api_pipeline.schema import PaginationSchemas

import airbyte as ab
import pathlib
import pathway as pw
import typing


def input(
def create_endpoint_stream(
api_name: str,
settings: Settings,
cache_name: str,
cache_dir: str,
source_manifest: dict | pathlib.Path,
stream_id: str,
pagination_schema: PaginationSchemas,
config: dict[str, typing.Any] | None = None,
source_manifest: bool | dict | pathlib.Path | str = False,
endpoint_text_fields: dict = dict(),
streams: str | list[str] | None = None,
force_full_refresh: bool = False
):
# settings = get_settings()
# class InputSchema(pw.Schema):
# value: int
# format="json"
# return pw.io.python.read(
# InfiniteStream(),
# schema=InputSchema,
# format=format,
# autocommit_duration_ms=get_settings().autocommit_duration_ms,
# )

) -> pw.Table:
cache = ab.caches.new_local_cache(
cache_name=f"{api_name}_cache",
cache_dir='./connector_cache', # NOTICE: parametrize
cleanup=force_full_refresh # NOTICE: CLI param
cache_name=cache_name,
cache_dir=cache_dir,
cleanup=force_full_refresh
)

api_connector = AirbyteAPIConnector(
stream_connector = AirbyteAPIConnector(
name=api_name,
cache=cache,
config=config,
source_manifest=source_manifest,
# streams="proposals", # NOTICE: doing just an individual stream
streams=streams,
force_full_refresh=force_full_refresh, # NOTICE: CLI param
check_source=False, # NOTICE: enforce to not check manifest during initialization
streams=stream_id,
force_full_refresh=force_full_refresh,
check_source=False,
)

try:
api_connector.check()
stream_connector.check()
except Exception as error:
print("FATAL error: manifest error", error) # TODO: logger
print(f"FATAL error: manifest error when creating {stream_id} stream", error) # TODO: logger
raise error

return pw.io.python.read(
api_connector,
schema=BoardroomAPI # TODO: get this from mappings? or define standard base schema
stream_connector,
schema=pagination_schema.value
)


def input(
api_name: str,
settings: Settings,
source_manifest: dict | pathlib.Path,
endpoints: dict,
pagination_schema: PaginationSchemas,
cache_dir: str = "./output/cache",
config: dict[str, typing.Any] | None = None,
force_full_refresh: bool = False
) -> typing.List[pw.Table]:
# settings = get_settings()
# class InputSchema(pw.Schema):
# value: int
# format="json"
# return pw.io.python.read(
# InfiniteStream(),
# schema=InputSchema,
# format=format,
# autocommit_duration_ms=get_settings().autocommit_duration_ms,
# )

streams = [details.get("stream_id", "") for endpoint, details in endpoints.items()]

print(f"input streams - {streams}") # TODO: logger

stream_tables = list()
for stream_id in streams:
stream = create_endpoint_stream(
api_name=api_name,
cache_name=f"{api_name}_{stream_id}_cache",
cache_dir=cache_dir,
source_manifest=source_manifest,
stream_id=stream_id,
pagination_schema=pagination_schema,
config=config,
force_full_refresh=force_full_refresh,
)
stream_tables.append(stream)

return stream_tables
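With this refactor, `create_endpoint_stream` builds one Airbyte cache, connector, and Pathway reader per declared stream, while `input` fans out over the endpoints discovered by the loader and returns one `pw.Table` per stream. A hedged usage sketch follows; the endpoint entries, manifest path, and config values are placeholders, not values taken from this commit:

```python
# Hypothetical wiring of the refactored input(); all literal values are placeholders.
import pathlib

from gaianet_rag_api_pipeline.config import get_settings
from gaianet_rag_api_pipeline.input import input as read_streams  # avoid shadowing the builtin
from gaianet_rag_api_pipeline.schema import PaginationSchemas

settings = get_settings()

# shape loosely following the loader's endpoint_text_fields output
endpoints = {
    "/proposals": {"stream_id": "proposals", "entrypoint_type": "array", "text_properties": []},
    "/protocols": {"stream_id": "protocols", "entrypoint_type": "array", "text_properties": []},
}

stream_tables = read_streams(
    api_name="boardroom_api",
    settings=settings,
    source_manifest=pathlib.Path("./output/boardroom_api_source_generated.yaml"),
    endpoints=endpoints,
    pagination_schema=list(PaginationSchemas)[0],  # normally the member picked by api_loader()
    cache_dir="./output/cache",
    config={"api_key": "..."},
    force_full_refresh=False,
)
# one pw.Table per endpoint stream, in the same order as `endpoints`
```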
79 changes: 57 additions & 22 deletions gaianet_rag_api_pipeline/loader.py
@@ -1,3 +1,4 @@
from gaianet_rag_api_pipeline.schema import PaginationSchemas
from gaianet_rag_api_pipeline.utils import resolve_refs

from openapi_spec_validator import validate_spec
@@ -109,43 +110,53 @@ def generate_source_manifest(
entrypoint_type = data_root.get("type", "")
#####

# validate textSchemas exist in response schema
# fields specified in textSchemas will be extracted to be preprocessed
# other response data fields will be included as json metadata
text_properties = details.\
get("textSchema", {}).\
get("properties", {})
data_fields = resolve_refs(data_root, openapi_spec)

# get response data fields
data_fields = resolve_refs(data_root, openapi_spec)
if entrypoint_type == "array":
# need to get properties from nested items spec
data_fields = data_fields.\
get("items", {})
data_fields = data_fields.get("properties", {})
fields_list = list(data_fields.keys())
print(f"endpoint text fields: {text_properties}") # TODO: logger
print(f"endpoint spec data fields: {data_fields}") # TODO: logger

# should validate textSchemas exist in response schema
# fields specified in textSchemas will be extracted to be preprocessed
# other response data fields will be included as json metadata
text_schemas = details.\
get("textSchema", {}).\
get("properties", {})

# validate text properties are in the endpoint openapi spec as response data fields
for field, props in text_properties.items():
text_properties = list()
for text_schema in text_schemas:
field = list(text_schema.keys())[0]
props = text_schema[field]
if field not in fields_list or props.get("type", "") != data_fields[field].get("type", ""):
error_msg = f"endpoint field not found or mismatch in openapi spec: {field} - {props}"
print(error_msg) # TODO: logger
raise Exception(error_msg)
text_properties.append(dict(field=field, **props))
print(f"endpoint text fields: {text_properties}") # TODO: logger

# build endpoints ids
stream_id = details.get("id", "")
stream_refId = f"{stream_id}_stream"

# update endpoint pre-process text fields
endpoint_text_fields[endpoint] = text_properties
endpoint_text_fields[endpoint] = dict(
stream_id=stream_id,
entrypoint_type=entrypoint_type,
text_properties=text_properties,
)

# setup pagination
needs_pagination = entrypoint_type == "array"
print(f"response schema: needs pagination? {needs_pagination} - {response_schema}") # TODO: logger
#####

# build endpoint stream definition
endpoint_id = details.get("id", "")
stream_id = f"{endpoint_id}_stream"
stream_definitions[stream_id] = {
stream_definitions[stream_refId] = {
"$ref": f"#/definitions/{'paging_stream' if needs_pagination else 'single_stream'}" ,
"schema_loader": {
"type": "InlineSchemaLoader",
@@ -154,17 +165,17 @@
},
},
"$parameters": {
"name": endpoint_id,
"name": stream_id,
"primary_key": response_primary_key,
"path": f'"{endpoint_path}"',
},
**request_options
}

stream_refs.append(f"#/definitions/{stream_id}")
stream_names.append(endpoint_id)
stream_refs.append(f"#/definitions/{stream_refId}")
stream_names.append(stream_id)

stream_yaml_spec = yaml.dump(stream_definitions[stream_id])
stream_yaml_spec = yaml.dump(stream_definitions[stream_refId])
print(f"stream spec:\n {stream_yaml_spec}\n\n") # TODO: logger

# build source manifest
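For a hypothetical `proposals` endpoint, the renamed identifiers above would key the definition by `proposals_stream` while `$parameters.name` keeps the plain `proposals` id. A sketch of the resulting entry (path, primary key, inline schema, and request options are placeholders):

```python
# Hypothetical shape of one generated stream definition entry; values are placeholders.
stream_refId = "proposals_stream"
stream_definitions = {
    stream_refId: {
        "$ref": "#/definitions/paging_stream",  # single_stream when no pagination is needed
        "schema_loader": {
            "type": "InlineSchemaLoader",
            "schema": {"type": "object", "properties": {}},  # inline response schema (placeholder)
        },
        "$parameters": {
            "name": "proposals",
            "primary_key": "refId",
            "path": '"/proposals"',
        },
        # request_options (params, headers, etc.) are merged in here
    }
}
```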
@@ -193,7 +204,7 @@ def api_loader(
mapping_file: pathlib.Path,
openapi_spec_file: pathlib.Path,
output_folder: str,
) -> Tuple[Tuple[str, dict], Tuple[dict, str]]:
) -> Tuple[Tuple[str, PaginationSchemas, dict], Tuple[dict, str], ]:
mappings = dict()
with open(mapping_file, "r") as f:
mappings = yaml.safe_load(f)
@@ -210,8 +221,16 @@ def api_loader(
defIds = list(mappings.get("definitions", {}).keys())
for refId in ["paging_stream", "single_stream"]:
if refId not in defIds:
print(f"{refId} is missing in mapping definitions") # TODO: logger
raise Exception(f"{refId} stream is missing in mappings")
error_msg = f"{refId} is missing in mapping definitions"
print(error_msg) # TODO: logger
raise Exception(error_msg)

# validate "base" retrievers are in mappings
for refId in ["retriever", "single_retriever"]:
if refId not in defIds:
error_msg = f"{refId} is missing in mapping definitions"
print(error_msg) # TODO: logger
raise Exception(error_msg)

# load openapi spec
(openapi_spec, _) = read_from_filename(openapi_spec_file)
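The hunk above makes the loader require both base stream templates and base retrievers under the mappings' `definitions`; the pagination strategy is then read from `retriever.paginator.pagination_strategy.type` further down. A small sketch of the minimal keys being checked (nested contents are placeholders, not the real mapping templates):

```python
# Minimal keys api_loader() now expects under mappings["definitions"].
# The nested contents below are placeholders, not the real templates.
required = ["paging_stream", "single_stream", "retriever", "single_retriever"]
definitions = {
    "paging_stream": {},
    "single_stream": {},
    "retriever": {"paginator": {"pagination_strategy": {"type": "OffsetIncrement"}}},
    "single_retriever": {},
}
missing = [refId for refId in required if refId not in definitions]
assert not missing, f"{missing} missing in mapping definitions"
```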
@@ -229,14 +248,30 @@ def api_loader(
openapi_spec=openapi_spec,
)

# get pagination strategy
definitions = mappings.get("definitions", {})
pagination_strategy = definitions.\
get("retriever", {}).\
get("paginator", {}).\
get("pagination_strategy", {}).\
get("type", None)

if pagination_strategy not in [schema.name for schema in PaginationSchemas]:
error_msg = f"Pagination strategy '{pagination_strategy}' not supported"
print(error_msg) # TODO: logger
raise Exception(error_msg)

pagination_schema = PaginationSchemas[pagination_strategy]
print(f"pagination schema: {pagination_schema.name}") # TODO: logger

# store generated manifest
output_file = f"{output_folder}/{api_name}_source_generated.yaml"
with open(output_file, "w") as out_file:
yaml.dump(source_manifest, out_file)
print(f"source manifest written to {output_file}") # TODO: logger

return (
(api_name, api_parameters),
(api_name, pagination_schema, api_parameters),
(source_manifest, endpoint_text_fields),
)
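`api_loader` now resolves the `pagination_strategy` type found in the mappings against the new `PaginationSchemas` enum, so every Pathway reader gets a schema matching the connector's pagination style. The enum itself is not shown in this commit; the sketch below is an assumption about its shape (members named after Airbyte pagination strategy types, values being Pathway schemas), not the actual gaianet_rag_api_pipeline.schema code:

```python
# Assumed shape of the PaginationSchemas enum referenced above; the member
# names and schema columns are guesses, not code from this commit.
from enum import Enum

import pathway as pw


class OffsetPaginationSchema(pw.Schema):
    record: pw.Json  # raw API record
    offset: int      # pagination offset


class CursorPaginationSchema(pw.Schema):
    record: pw.Json   # raw API record
    next_cursor: str  # pagination cursor


class PaginationSchemas(Enum):
    OffsetIncrement = OffsetPaginationSchema
    CursorPagination = CursorPaginationSchema


# lookup style used by api_loader(): strategy type string -> Pathway schema
schema = PaginationSchemas["OffsetIncrement"]
print(schema.name, schema.value)
```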

35 changes: 30 additions & 5 deletions gaianet_rag_api_pipeline/pipeline.py
@@ -4,14 +4,39 @@
from gaianet_rag_api_pipeline.serialize import jsonl_serialize

import pathway as pw
import typing


def pipeline(input_table: pw.Table) -> pw.Table:
"""Your custom logic."""
def pipeline(
endpoints: dict,
stream_tables: typing.List[pw.Table]
) -> pw.Table:
"""Preprocessing each endpoint stream."""

preprocessed_table = preprocessing(input_table)
jsonl_serialize("preprocessed", preprocessed_table) # TODO: remove
chunked_table = chunking(preprocessed_table)
# preprocess and normalize data from each endpoint stream
preprocessed_streams: typing.List[pw.Table] = list()
for i, (_, details) in enumerate(endpoints.items()):
entrypoint_type = details.get("entrypoint_type")
text_properties = details.get("text_properties")
stream = preprocessing(
input_stream=stream_tables[i],
entrypoint_type=entrypoint_type,
text_properties=text_properties
)
# stream = stream_tables[i]
pw.io.jsonlines.write(stream, f"./output/preview-stream{i}.jsonl")
preprocessed_streams.append(stream)

# concat data from all endpoint streams
master_table = None
for stream in preprocessed_streams:
if not master_table:
master_table = stream
continue
master_table = master_table.concat_reindex(stream)

jsonl_serialize("preprocessed", master_table) # TODO: remove
chunked_table = chunking(master_table)
embeddings_table = embeddings(chunked_table)

return embeddings_table
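`pipeline` now preprocesses each endpoint stream on its own and folds the results into a single table with `concat_reindex` before chunking and embedding. The accumulation loop is a left fold; a minimal, self-contained equivalent with toy tables (column and row values are placeholders):

```python
# Minimal sketch: folding per-endpoint tables into one with concat_reindex,
# equivalent to the master_table loop in pipeline(). Tables are toy placeholders.
import functools

import pathway as pw

t1 = pw.debug.table_from_markdown(
    """
      | text
    1 | proposal_one
    2 | proposal_two
    """
)
t2 = pw.debug.table_from_markdown(
    """
      | text
    1 | topic_one
    """
)

master_table = functools.reduce(lambda acc, t: acc.concat_reindex(t), [t1, t2])
pw.debug.compute_and_print(master_table)
```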