From 13afb0abe604e4862db663c988cc69a107fa4879 Mon Sep 17 00:00:00 2001 From: Santiago Gonzalez Date: Sat, 31 Aug 2024 01:10:30 -0500 Subject: [PATCH] implement stream preprocessors --- config/mappings.yaml | 42 +++++----- gaianet_rag_api_pipeline/input.py | 91 ++++++++++++++------- gaianet_rag_api_pipeline/loader.py | 79 +++++++++++++----- gaianet_rag_api_pipeline/pipeline.py | 35 ++++++-- gaianet_rag_api_pipeline/preprocessing.py | 29 +++++-- gaianet_rag_api_pipeline/schema/__init__.py | 9 +- gaianet_rag_api_pipeline/schema/base.py | 22 ++++- gaianet_rag_api_pipeline/udfs/__init__.py | 41 +++++----- gaianet_rag_api_pipeline/udfs/json.py | 55 +++++++++++++ output/cache/.gitkeep | 0 run.py | 22 +++-- 11 files changed, 305 insertions(+), 120 deletions(-) create mode 100644 gaianet_rag_api_pipeline/udfs/json.py create mode 100644 output/cache/.gitkeep diff --git a/config/mappings.yaml b/config/mappings.yaml index 6c997b8..4cd8306 100644 --- a/config/mappings.yaml +++ b/config/mappings.yaml @@ -164,37 +164,37 @@ textSchemas: Protocol: type: object properties: - cname: - type: string - name: - type: string - categories: - type: array - items: + - cname: + type: string + - name: type: string + - categories: + type: array + items: + type: string Proposal: type: object properties: - title: - type: string - content: - type: string - summary: - type: string + - title: + type: string + - content: + type: string + - summary: + type: string DiscourseTopic: type: object properties: - title: - type: string + - title: + type: string DiscourseCategory: type: object properties: - name: - type: string - description: - type: string + - name: + type: string + - description: + type: string DiscourseTopicPost: type: object properties: - body: - type: string + - body: + type: string diff --git a/gaianet_rag_api_pipeline/input.py b/gaianet_rag_api_pipeline/input.py index 85f16d9..1327c60 100644 --- a/gaianet_rag_api_pipeline/input.py +++ b/gaianet_rag_api_pipeline/input.py @@ -1,6 +1,6 @@ from gaianet_rag_api_pipeline.config import get_settings, Settings from gaianet_rag_api_pipeline.io.airbyte import AirbyteAPIConnector -from gaianet_rag_api_pipeline.schema import BoardroomAPI # TODO: +from gaianet_rag_api_pipeline.schema import PaginationSchemas import airbyte as ab import pathlib @@ -8,50 +8,81 @@ import typing -def input( +def create_endpoint_stream( api_name: str, - settings: Settings, + cache_name: str, + cache_dir: str, + source_manifest: dict | pathlib.Path, + stream_id: str, + pagination_schema: PaginationSchemas, config: dict[str, typing.Any] | None = None, - source_manifest: bool | dict | pathlib.Path | str = False, - endpoint_text_fields: dict = dict(), - streams: str | list[str] | None = None, force_full_refresh: bool = False -): - # settings = get_settings() - # class InputSchema(pw.Schema): - # value: int - # format="json" - # return pw.io.python.read( - # InfiniteStream(), - # schema=InputSchema, - # format=format, - # autocommit_duration_ms=get_settings().autocommit_duration_ms, - # ) - +) -> pw.Table: cache = ab.caches.new_local_cache( - cache_name=f"{api_name}_cache", - cache_dir='./connector_cache', # NOTICE: parametrize - cleanup=force_full_refresh # NOTICE: CLI param + cache_name=cache_name, + cache_dir=cache_dir, + cleanup=force_full_refresh ) - api_connector = AirbyteAPIConnector( + stream_connector = AirbyteAPIConnector( name=api_name, cache=cache, config=config, source_manifest=source_manifest, - # streams="proposals", # NOTICE: doing just an individual stream - streams=streams, - 
force_full_refresh=force_full_refresh, # NOTICE: CLI param - check_source=False, # NOTICE: enforce to not check manifest during initialization + streams=stream_id, + force_full_refresh=force_full_refresh, + check_source=False, ) try: - api_connector.check() + stream_connector.check() except Exception as error: - print("FATAL error: manifest error", error) # TODO: logger + print(f"FATAL error: manifest error when creating {stream_id} stream", error) # TODO: logger raise error return pw.io.python.read( - api_connector, - schema=BoardroomAPI # TODO: get this from mappings? or define standard base schema + stream_connector, + schema=pagination_schema.value ) + + +def input( + api_name: str, + settings: Settings, + source_manifest: dict | pathlib.Path, + endpoints: dict, + pagination_schema: PaginationSchemas, + cache_dir: str = "./output/cache", + config: dict[str, typing.Any] | None = None, + force_full_refresh: bool = False +) -> typing.List[pw.Table]: + # settings = get_settings() + # class InputSchema(pw.Schema): + # value: int + # format="json" + # return pw.io.python.read( + # InfiniteStream(), + # schema=InputSchema, + # format=format, + # autocommit_duration_ms=get_settings().autocommit_duration_ms, + # ) + + streams = [details.get("stream_id", "") for endpoint, details in endpoints.items()] + + print(f"input streams - {streams}") # TODO: logger + + stream_tables = list() + for stream_id in streams: + stream = create_endpoint_stream( + api_name=api_name, + cache_name=f"{api_name}_{stream_id}_cache", + cache_dir=cache_dir, + source_manifest=source_manifest, + stream_id=stream_id, + pagination_schema=pagination_schema, + config=config, + force_full_refresh=force_full_refresh, + ) + stream_tables.append(stream) + + return stream_tables diff --git a/gaianet_rag_api_pipeline/loader.py b/gaianet_rag_api_pipeline/loader.py index 53cc025..0fe33e1 100644 --- a/gaianet_rag_api_pipeline/loader.py +++ b/gaianet_rag_api_pipeline/loader.py @@ -1,3 +1,4 @@ +from gaianet_rag_api_pipeline.schema import PaginationSchemas from gaianet_rag_api_pipeline.utils import resolve_refs from openapi_spec_validator import validate_spec @@ -109,33 +110,45 @@ def generate_source_manifest( entrypoint_type = data_root.get("type", "") ##### - # validate textSchemas exist in response schema - # fields specified in textSchemas will be extracted to be preprocessed - # other response data fields will be included as json metadata - text_properties = details.\ - get("textSchema", {}).\ - get("properties", {}) - data_fields = resolve_refs(data_root, openapi_spec) - # get response data fields + data_fields = resolve_refs(data_root, openapi_spec) if entrypoint_type == "array": # need to get properties from nested items spec data_fields = data_fields.\ get("items", {}) data_fields = data_fields.get("properties", {}) fields_list = list(data_fields.keys()) - print(f"endpoint text fields: {text_properties}") # TODO: logger print(f"endpoint spec data fields: {data_fields}") # TODO: logger + # should validate textSchemas exist in response schema + # fields specified in textSchemas will be extracted to be preprocessed + # other response data fields will be included as json metadata + text_schemas = details.\ + get("textSchema", {}).\ + get("properties", {}) + # validate text properties are in the endpoint openapi spec as response data fields - for field, props in text_properties.items(): + text_properties = list() + for text_schema in text_schemas: + field = list(text_schema.keys())[0] + props = text_schema[field] if field not in 
fields_list or props.get("type", "") != data_fields[field].get("type", ""): error_msg = f"endpoint field not found or mismatch in openapi spec: {field} - {props}" print(error_msg) # TODO: logger raise Exception(error_msg) + text_properties.append(dict(field=field, **props)) + print(f"endpoint text fields: {text_properties}") # TODO: logger + + # build endpoints ids + stream_id = details.get("id", "") + stream_refId = f"{stream_id}_stream" # update endpoint pre-process text fields - endpoint_text_fields[endpoint] = text_properties + endpoint_text_fields[endpoint] = dict( + stream_id=stream_id, + entrypoint_type=entrypoint_type, + text_properties=text_properties, + ) # setup pagination needs_pagination = entrypoint_type == "array" @@ -143,9 +156,7 @@ def generate_source_manifest( ##### # build endpoint stream definition - endpoint_id = details.get("id", "") - stream_id = f"{endpoint_id}_stream" - stream_definitions[stream_id] = { + stream_definitions[stream_refId] = { "$ref": f"#/definitions/{'paging_stream' if needs_pagination else 'single_stream'}" , "schema_loader": { "type": "InlineSchemaLoader", @@ -154,17 +165,17 @@ def generate_source_manifest( }, }, "$parameters": { - "name": endpoint_id, + "name": stream_id, "primary_key": response_primary_key, "path": f'"{endpoint_path}"', }, **request_options } - stream_refs.append(f"#/definitions/{stream_id}") - stream_names.append(endpoint_id) + stream_refs.append(f"#/definitions/{stream_refId}") + stream_names.append(stream_id) - stream_yaml_spec = yaml.dump(stream_definitions[stream_id]) + stream_yaml_spec = yaml.dump(stream_definitions[stream_refId]) print(f"stream spec:\n {stream_yaml_spec}\n\n") # TODO: logger # build source manifest @@ -193,7 +204,7 @@ def api_loader( mapping_file: pathlib.Path, openapi_spec_file: pathlib.Path, output_folder: str, -) -> Tuple[Tuple[str, dict], Tuple[dict, str]]: +) -> Tuple[Tuple[str, PaginationSchemas, dict], Tuple[dict, str], ]: mappings = dict() with open(mapping_file, "r") as f: mappings = yaml.safe_load(f) @@ -210,8 +221,16 @@ def api_loader( defIds = list(mappings.get("definitions", {}).keys()) for refId in ["paging_stream", "single_stream"]: if refId not in defIds: - print(f"{refId} is missing in mapping definitions") # TODO: logger - raise Exception(f"{refId} stream is missing in mappings") + error_msg = f"{refId} is missing in mapping definitions" + print(error_msg) # TODO: logger + raise Exception(error_msg) + + # validate "base" retrievers are in mappings + for refId in ["retriever", "single_retriever"]: + if refId not in defIds: + error_msg = f"{refId} is missing in mapping definitions" + print(error_msg) # TODO: logger + raise Exception(error_msg) # load openapi spec (openapi_spec, _) = read_from_filename(openapi_spec_file) @@ -229,6 +248,22 @@ def api_loader( openapi_spec=openapi_spec, ) + # get pagination strategy + definitions = mappings.get("definitions", {}) + pagination_strategy = definitions.\ + get("retriever", {}).\ + get("paginator", {}).\ + get("pagination_strategy", {}).\ + get("type", None) + + if pagination_strategy not in [schema.name for schema in PaginationSchemas]: + error_msg = f"Pagination strategy '{pagination_strategy}' not supported" + print(error_msg) # TODO: logger + raise Exception(error_msg) + + pagination_schema = PaginationSchemas[pagination_strategy] + print(f"pagination schema: {pagination_schema.name}") # TODO: logger + # store generated manifest output_file = f"{output_folder}/{api_name}_source_generated.yaml" with open(output_file, "w") as out_file: @@ -236,7 
+271,7 @@ def api_loader( print(f"source manifest written to {output_file}") # TODO: logger return ( - (api_name, api_parameters), + (api_name, pagination_schema, api_parameters), (source_manifest, endpoint_text_fields), ) diff --git a/gaianet_rag_api_pipeline/pipeline.py b/gaianet_rag_api_pipeline/pipeline.py index c31b278..71d3f9e 100644 --- a/gaianet_rag_api_pipeline/pipeline.py +++ b/gaianet_rag_api_pipeline/pipeline.py @@ -4,14 +4,39 @@ from gaianet_rag_api_pipeline.serialize import jsonl_serialize import pathway as pw +import typing -def pipeline(input_table: pw.Table) -> pw.Table: - """Your custom logic.""" +def pipeline( + endpoints: dict, + stream_tables: typing.List[pw.Table] +) -> pw.Table: + """Preprocessing each endpoint stream.""" - preprocessed_table = preprocessing(input_table) - jsonl_serialize("preprocessed", preprocessed_table) # TODO: remove - chunked_table = chunking(preprocessed_table) + # preprocess and normalize data from each endpoint stream + preprocessed_streams: typing.List[pw.Table] = list() + for i, (_, details) in enumerate(endpoints.items()): + entrypoint_type = details.get("entrypoint_type") + text_properties = details.get("text_properties") + stream = preprocessing( + input_stream=stream_tables[i], + entrypoint_type=entrypoint_type, + text_properties=text_properties + ) + # stream = stream_tables[i] + pw.io.jsonlines.write(stream, f"./output/preview-stream{i}.jsonl") + preprocessed_streams.append(stream) + + # concat data from all endpoint streams + master_table = None + for stream in preprocessed_streams: + if not master_table: + master_table = stream + continue + master_table = master_table.concat_reindex(stream) + + jsonl_serialize("preprocessed", master_table) # TODO: remove + chunked_table = chunking(master_table) embeddings_table = embeddings(chunked_table) return embeddings_table diff --git a/gaianet_rag_api_pipeline/preprocessing.py b/gaianet_rag_api_pipeline/preprocessing.py index 22f46f4..0d0a077 100644 --- a/gaianet_rag_api_pipeline/preprocessing.py +++ b/gaianet_rag_api_pipeline/preprocessing.py @@ -1,16 +1,27 @@ -from gaianet_rag_api_pipeline.udfs import to_json +from gaianet_rag_api_pipeline.udfs import filter_json, json_concat_fields_with_meta, to_json import pathway as pw -def preprocessing(input_table: pw.Table) -> pw.Table: +def preprocessing( + input_stream: pw.Table, + entrypoint_type: str, + text_properties: list[dict] +) -> pw.Table: # NOTICE: With Airbyte we need to parse data to Json during pre-processing - input_table = input_table.with_columns( - data=to_json(input_table.data) + input_stream = input_stream.with_columns( + data=to_json(pw.this.data) ) - # TODO: should flatten based on endpoint metadata - # case: protocol endpoint doesn't need flattening - # output_table = input_table.flatten(input_table.data) - # return output_table - return input_table + # should flatten results if endpoint returns multiple records + if entrypoint_type == "array": + input_stream = input_stream.flatten(input_stream.data) + + # normalization + preprocess_fields = [p.get("field") for p in text_properties] + output_table = input_stream.with_columns( + content=json_concat_fields_with_meta(pw.this.data, text_properties), + metadata=filter_json(pw.this.data, preprocess_fields), + ) + + return output_table diff --git a/gaianet_rag_api_pipeline/schema/__init__.py b/gaianet_rag_api_pipeline/schema/__init__.py index 7c08159..f9d1fd3 100644 --- a/gaianet_rag_api_pipeline/schema/__init__.py +++ b/gaianet_rag_api_pipeline/schema/__init__.py @@ -1,3 +1,8 @@ 
-from .base import BoardroomAPI +from .base import CursorBasedAPISchema, OffsetBasedAPISchema, PageBasedAPISchema, PaginationSchemas -__all__ = ["BoardroomAPI"] # TODO: update \ No newline at end of file +__all__ = [ + "CursorBasedAPISchema", + "OffsetBasedAPISchema", + "PageBasedAPISchema", + "PaginationSchemas" +] diff --git a/gaianet_rag_api_pipeline/schema/base.py b/gaianet_rag_api_pipeline/schema/base.py index 2674c25..3e1dad2 100644 --- a/gaianet_rag_api_pipeline/schema/base.py +++ b/gaianet_rag_api_pipeline/schema/base.py @@ -1,3 +1,4 @@ +from enum import Enum from pathway import DateTimeNaive, Json, Schema @@ -5,10 +6,25 @@ class AirbyteSchema(Schema): _airbyte_raw_id: str _airbyte_extracted_at: DateTimeNaive _airbyte_meta: dict + stream: str -# TODO: define general schema name with different pagination strategies -class BoardroomAPI(AirbyteSchema): - stream: str +class CursorBasedAPISchema(AirbyteSchema): data: Json nextcursor: str | None + + +# TODO: complete implementation +class OffsetBasedAPISchema(AirbyteSchema): + data: Json + + +# TODO: complete implementation +class PageBasedAPISchema(AirbyteSchema): + data: Json + + +class PaginationSchemas(Enum): + CursorPagination = CursorBasedAPISchema + OffsetIncrement = OffsetBasedAPISchema + PageIncrement = PageBasedAPISchema diff --git a/gaianet_rag_api_pipeline/udfs/__init__.py b/gaianet_rag_api_pipeline/udfs/__init__.py index 0a09b97..45900cb 100644 --- a/gaianet_rag_api_pipeline/udfs/__init__.py +++ b/gaianet_rag_api_pipeline/udfs/__init__.py @@ -1,4 +1,11 @@ -from gaianet_rag_api_pipeline.udfs.reducers import JSONAccumulator +from .json import ( + filter_json, + json_concat_fields, + json_concat_fields_with_meta, + json_merge, + json_stringify, to_json +) +from .reducers import JSONAccumulator import json import pathway as pw @@ -8,21 +15,17 @@ @pw.udf -def append_parent_id(content: pw.Json, parent_id: str) -> pw.Json: - data = { "parent_id": parent_id, **content.as_dict() } - return data - - -@pw.udf -def to_json(val: pw.Json) -> pw.Json: - return pw.Json(json.loads(val.as_str())) - - -@pw.udf -def filter_document(document: pw.Json, fields: list[str]) -> pw.Json: - data = { **document.as_dict() } - # data = { "refId": document["refId"] } - for field in fields: - if field in data: - data.pop(field) - return data +def concat_fields(separator: str, *fields) -> str: + return f"{separator}".join(fields) + + +__all__ = [ + "concat_fields", + "filter_json", + "json_concat_fields", + "json_concat_fields_with_meta", + "json_merge", + "json_reducer" + "json_stringify", + "to_json", +] diff --git a/gaianet_rag_api_pipeline/udfs/json.py b/gaianet_rag_api_pipeline/udfs/json.py new file mode 100644 index 0000000..ee04245 --- /dev/null +++ b/gaianet_rag_api_pipeline/udfs/json.py @@ -0,0 +1,55 @@ +import json +import pathway as pw +import typing + + +@pw.udf +def filter_json(document: pw.Json, fields_to_remove: list[str]) -> pw.Json: + data = { **document.as_dict() } + # data = { "refId": document["refId"] } + for field in fields_to_remove: + if field in data: + data.pop(field) + return data + + +@pw.udf +def json_concat_fields(data: pw.Json, fields: list[str]) -> str: + values = [data[fname].as_str() for fname in fields] + return "\n\n".join(values) + + +@pw.udf +def json_concat_fields_with_meta(data: pw.Json, fields: list[pw.Json]) -> str: + values = list() + global_field = None + try: + for field_meta in fields: + meta = field_meta.as_dict() + global_field = field_meta + label = meta.get("field") + field_value = data[label] + if not 
field_value: + continue + if meta.get("type") == "array": + values.append(f"{label}: {','.join(field_value.as_list())}") + else: + values.append(field_value.as_str()) + except Exception as error: + raise Exception(f"FAILED in {global_field} - {data.as_dict()}", error) + return "\n\n".join(values) + + +@pw.udf +def json_merge(base: pw.Json, content: pw.Json) -> pw.Json: + return { **base.as_dict(), **content.as_dict()} + + +@pw.udf +def json_stringify(data: pw.Json) -> str: + return json.dumps(data.as_dict()) + + +@pw.udf +def to_json(val: pw.Json) -> pw.Json: + return pw.Json(json.loads(val.as_str())) \ No newline at end of file diff --git a/output/cache/.gitkeep b/output/cache/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/run.py b/run.py index 3a4547a..7314390 100644 --- a/run.py +++ b/run.py @@ -37,8 +37,8 @@ def run( # TODO: omit source generation if source manifest is specified as optional param ( - (api_name, api_parameters), - (source_manifest, endpoint_text_fields) + (api_name, pagination_schema, api_parameters), + (source_manifest, endpoints) ) = api_loader( mapping_file=pathlib.Path(mapping_manifest_file), openapi_spec_file=pathlib.Path(settings.openapi_spec_file), @@ -46,22 +46,26 @@ def run( ) print(f"api config: {api_name} - {api_parameters}") - # print(f"endpoint_text_fields - {endpoint_text_fields}") + print(f"endpoints - {endpoints}") - input_table = input( + stream_tables = input( api_name=api_name, settings=settings, + endpoints=endpoints, + pagination_schema=pagination_schema, + # source_manifest=pathlib.Path(settings.api_manifest_file), # NOTICE: CLI parma BUT should come as dict after generation + source_manifest=source_manifest, config=dict( api_key=settings.api_key, **api_parameters ), - # source_manifest=pathlib.Path(settings.api_manifest_file), # NOTICE: CLI parma BUT should come as dict after generation - source_manifest=source_manifest, - endpoint_text_fields=endpoint_text_fields, - # streams="proposals", # TODO: for now extract proposals only force_full_refresh=full_refresh # NOTICE: CLI param ) - output_table = pipeline(input_table) + + output_table = pipeline( + endpoints=endpoints, + stream_tables=stream_tables + ) output(output_table) pw.run(monitoring_level=pw.MonitoringLevel.ALL)
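
A few usage sketches for the pieces introduced above follow. First, the mappings.yaml change turns each entry under textSchemas.<Entity>.properties into a single-key list item, which loader.py now unpacks with list(text_schema.keys())[0]. Below is a minimal sketch of how that list form parses, using a trimmed, hypothetical excerpt of the mapping (PyYAML only, not the loader itself); `type` sits one level under the field name so each item is a single-key mapping:

    import yaml

    # trimmed, hypothetical excerpt of config/mappings.yaml
    mapping_excerpt = """
    textSchemas:
      Proposal:
        type: object
        properties:
          - title:
              type: string
          - content:
              type: string
    """

    schema = yaml.safe_load(mapping_excerpt)["textSchemas"]["Proposal"]

    text_properties = []
    for entry in schema.get("properties", []):
        field = list(entry.keys())[0]   # e.g. "title"
        props = entry[field] or {}      # e.g. {"type": "string"}
        text_properties.append(dict(field=field, **props))

    print(text_properties)
    # [{'field': 'title', 'type': 'string'}, {'field': 'content', 'type': 'string'}]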
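
loader.py also derives the Pathway schema for every stream from the retriever paginator declared in the mappings (definitions.retriever.paginator.pagination_strategy.type) through the new PaginationSchemas enum. A minimal sketch of that lookup, with dependency-free stand-in classes in place of the pw.Schema subclasses from schema/base.py:

    from enum import Enum

    # stand-ins for the Pathway schemas defined in schema/base.py,
    # so the sketch has no pathway dependency
    class CursorBasedAPISchema: ...
    class OffsetBasedAPISchema: ...
    class PageBasedAPISchema: ...

    class PaginationSchemas(Enum):
        # member names mirror the Airbyte declarative `pagination_strategy.type` values
        CursorPagination = CursorBasedAPISchema
        OffsetIncrement = OffsetBasedAPISchema
        PageIncrement = PageBasedAPISchema

    def select_pagination_schema(definitions: dict) -> PaginationSchemas:
        strategy = (
            definitions.get("retriever", {})
            .get("paginator", {})
            .get("pagination_strategy", {})
            .get("type")
        )
        if strategy not in PaginationSchemas.__members__:
            raise Exception(f"Pagination strategy '{strategy}' not supported")
        return PaginationSchemas[strategy]

    # example: a cursor-paginated retriever definition from the mappings file
    definitions = {"retriever": {"paginator": {"pagination_strategy": {"type": "CursorPagination"}}}}
    print(select_pagination_schema(definitions))        # PaginationSchemas.CursorPagination
    print(select_pagination_schema(definitions).value)  # the schema class handed to the stream reader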
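
Finally, preprocessing.py now splits every stream record into a text content column (json_concat_fields_with_meta over the mapped text properties) and a JSON metadata column with the remaining fields (filter_json). A plain-Python sketch of that per-record normalization on a bare dict, outside the Pathway UDF machinery; record and field names are illustrative only:

    def normalize_record(record: dict, text_properties: list[dict]) -> dict:
        text_fields = [p["field"] for p in text_properties]

        # content: text fields joined with blank lines; array-typed fields are
        # labelled and comma-joined (mirrors json_concat_fields_with_meta)
        parts = []
        for prop in text_properties:
            value = record.get(prop["field"])
            if not value:
                continue
            if prop.get("type") == "array":
                parts.append(f"{prop['field']}: {','.join(value)}")
            else:
                parts.append(value)

        # metadata: every field that was not extracted as text (mirrors filter_json)
        metadata = {k: v for k, v in record.items() if k not in text_fields}

        return {"content": "\n\n".join(parts), "metadata": metadata}

    record = {"title": "Treasury report", "categories": ["dao", "finance"], "id": 42}
    props = [{"field": "title", "type": "string"}, {"field": "categories", "type": "array"}]
    print(normalize_record(record, props))
    # {'content': 'Treasury report\n\ncategories: dao,finance', 'metadata': {'id': 42}}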