From b46bdb21ef9ca541a924399be6f9374d58fc93db Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 21 Dec 2024 20:07:11 +0400 Subject: [PATCH 1/3] feat: custom callback config options and initial tests --- graphai/callback.py | 180 ++++++++++++++++++++++++--- tests/unit/test_callback.py | 242 ++++++++++++++++++++++++++++++++++++ 2 files changed, 407 insertions(+), 15 deletions(-) create mode 100644 tests/unit/test_callback.py diff --git a/graphai/callback.py b/graphai/callback.py index 4de897a..1e9cb45 100644 --- a/graphai/callback.py +++ b/graphai/callback.py @@ -1,4 +1,5 @@ import asyncio +from pydantic import Field from typing import Optional from collections.abc import AsyncIterator from semantic_router.utils.logger import logger @@ -7,20 +8,113 @@ log_stream = True class Callback: - first_token = True - current_node_name: Optional[str] = None - active: bool = True + identifier: str = Field( + default="graphai", + description=( + "The identifier for special tokens. This allows us to easily " + "identify special tokens in the stream so we can handle them " + "correctly in any downstream process." + ) + ) + special_token_format: str = Field( + default="<{identifier}:{token}:{params}>", + description=( + "The format for special tokens. This is used to format special " + "tokens so they can be easily identified in the stream. " + "The format is a string with three possible components:\n" + "- {identifier}: An identifier shared by all special tokens, " + "by default this is 'graphai'.\n" + "- {token}: The special token type to be streamed. This may " + "be a tool name, identifier for start/end nodes, etc.\n" + "- {params}: Any additional parameters to be streamed. The parameters " + "are formatted as a comma-separated list of key-value pairs." + ), + examples=[ + "<{identifier}:{token}:{params}>", + "<[{identifier} | {token} | {params}]>", + "<{token}:{params}>" + ] + ) + token_format: str = Field( + default="{token}", + description=( + "The format for streamed tokens. This is used to format the " + "tokens typically returned from LLMs. By default, no special " + "formatting is applied." + ) + ) + _first_token: bool = Field( + default=True, + description="Whether this is the first token in the stream.", + exclude=True + ) + _current_node_name: Optional[str] = Field( + default=None, + description="The name of the current node.", + exclude=True + ) + _active: bool = Field( + default=True, + description="Whether the callback is active.", + exclude=True + ) + _done: bool = Field( + default=False, + description="Whether the stream is done and should be closed.", + exclude=True + ) queue: asyncio.Queue - def __init__(self): + def __init__( + self, + identifier: str = "graphai", + special_token_format: str = "<{identifier}:{token}:{params}>", + token_format: str = "{token}", + ): + self.identifier = identifier + self.special_token_format = special_token_format + self.token_format = token_format self.queue = asyncio.Queue() + self._done = False + self._first_token = True + self._current_node_name = None + self._active = True + + @property + def first_token(self) -> bool: + return self._first_token + + @first_token.setter + def first_token(self, value: bool): + self._first_token = value + + @property + def current_node_name(self) -> Optional[str]: + return self._current_node_name + + @current_node_name.setter + def current_node_name(self, value: Optional[str]): + self._current_node_name = value + + @property + def active(self) -> bool: + return self._active + + @active.setter + def active(self, value: bool): + self._active = value def __call__(self, token: str, node_name: Optional[str] = None): + if self._done: + raise RuntimeError("Cannot add tokens to a closed stream") self._check_node_name(node_name=node_name) # otherwise we just assume node is correct and send token self.queue.put_nowait(token) async def acall(self, token: str, node_name: Optional[str] = None): + # TODO JB: do we need to have `node_name` param? + if self._done: + raise RuntimeError("Cannot add tokens to a closed stream") self._check_node_name(node_name=node_name) # otherwise we just assume node is correct and send token self.queue.put_nowait(token) @@ -30,29 +124,70 @@ async def aiter(self) -> AsyncIterator[str]: a generator that yields tokens from the queue until the END token is received. """ - while True: - token = await self.queue.get() - yield token - self.queue.task_done() - if token == "": + end_token = await self._build_special_token( + name="END", + params=None + ) + while True: # Keep going until we see the END token + try: + if self._done and self.queue.empty(): + break + token = await self.queue.get() + yield token + self.queue.task_done() + if token == end_token: + break + except asyncio.CancelledError: break + self._done = True # Mark as done after processing all tokens async def start_node(self, node_name: str, active: bool = True): + if self._done: + raise RuntimeError("Cannot start node on a closed stream") self.current_node_name = node_name if self.first_token: - # TODO JB: not sure if we need self.first_token self.first_token = False self.active = active if self.active: - self.queue.put_nowait(f"") + token = await self._build_special_token( + name="start", + params=None + ) + self.queue.put_nowait(token) + # TODO JB: should we use two tokens here? + node_token = await self._build_special_token( + name=self.current_node_name, + params=None + ) + self.queue.put_nowait(node_token) async def end_node(self, node_name: str): - self.current_node_name = None + if self._done: + raise RuntimeError("Cannot end node on a closed stream") + self.current_node_name = node_name if self.active: - self.queue.put_nowait(f"") + node_token = await self._build_special_token( + name=self.current_node_name, + params=None + ) + self.queue.put_nowait(node_token) + # TODO JB: should we use two tokens here or jump to END only? + await self.close() async def close(self): - self.queue.put_nowait("") + """Close the stream and prevent further tokens from being added. + This will send an END token and set the done flag to True. + """ + if self._done: + return + end_token = await self._build_special_token( + name="END", + params=None + ) + self._done = True # Set done before putting the end token + self.queue.put_nowait(end_token) + # Don't wait for queue.join() as it can cause deadlock + # The stream will close when aiter processes the END token def _check_node_name(self, node_name: Optional[str] = None): if node_name: @@ -60,4 +195,19 @@ def _check_node_name(self, node_name: Optional[str] = None): if self.current_node_name != node_name: raise ValueError( f"Node name mismatch: {self.current_node_name} != {node_name}" - ) \ No newline at end of file + ) + + async def _build_special_token(self, name: str, params: dict[str, any] | None = None): + if params: + params_str = ",".join([f"{k}={v}" for k, v in params.items()]) + else: + params_str = "" + if self.identifier: + identifier = self.identifier + else: + identifier = "" + return self.special_token_format.format( + identifier=identifier, + token=name, + params=params_str + ) diff --git a/tests/unit/test_callback.py b/tests/unit/test_callback.py new file mode 100644 index 0000000..412e49a --- /dev/null +++ b/tests/unit/test_callback.py @@ -0,0 +1,242 @@ +import pytest +import asyncio +from graphai.callback import Callback +from graphai import node, Graph +@pytest.fixture +async def callback(): + cb = Callback() + yield cb + await cb.close() + +@pytest.fixture +async def define_graph(): + """Define a graph with nodes that stream and don't stream. + """ + @node(start=True) + async def node_start(input: str): + # no stream added here + return {"input": input} + + @node(stream=True) + async def node_a(input: str, callback: Callback): + tokens = ["Hello", " ", "World", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node(stream=True) + async def node_b(input: str, callback: Callback): + tokens = ["Here", " ", "is", " ", "node", " ", "B", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node + async def node_c(input: str): + # no stream added here + return {"input": input} + + @node(stream=True) + async def node_d(input: str, callback: Callback): + tokens = ["Here", " ", "is", " ", "node", " ", "D", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node(end=True) + async def node_end(input: str): + return {"input": input} + + graph = Graph() + + nodes = [node_start, node_a, node_b, node_c, node_d, node_end] + + for i, node_fn in enumerate(nodes): + graph.add_node(node_fn) + if i > 0: + graph.add_edge(nodes[i-1], node_fn) + + return graph + +async def stream(cb: Callback, text: str): + tokens = text.split(" ") + for token in tokens: + await cb.acall(token) + await cb.close() + return + +class TestCallbackConfig: + @pytest.mark.asyncio + async def test_callback_initialization(self): + """Test basic initialization of Callback class""" + cb = Callback() + assert cb.identifier == "graphai" + assert cb.special_token_format == "<{identifier}:{token}:{params}>" + assert cb.token_format == "{token}" + assert isinstance(cb.queue, asyncio.Queue) + + @pytest.mark.asyncio + async def test_custom_initialization(self): + """Test initialization with custom parameters""" + cb = Callback( + identifier="custom", + special_token_format="[{identifier}:{token}:{params}]", + token_format="<<{token}>>" + ) + assert cb.identifier == "custom" + assert cb.special_token_format == "[{identifier}:{token}:{params}]" + assert cb.token_format == "<<{token}>>" + # create streaming task + response = asyncio.create_task(stream(cb, "Hello")) + out_tokens = [] + # now stream + async for token in cb.aiter(): + out_tokens.append(token) + assert out_tokens == ["Hello", "[custom:END:]"] + + @pytest.mark.asyncio + async def test_default_tokens(self): + """Test default tokens""" + cb = Callback() + # create streaming task + response = asyncio.create_task(stream(cb, "Hello")) + out_tokens = [] + # now stream + async for token in cb.aiter(): + out_tokens.append(token) + assert out_tokens == ["Hello", ""] + + @pytest.mark.asyncio + async def test_custom_tokens(self): + """Test custom tokens""" + cb = Callback( + identifier="custom", + special_token_format="[{identifier}:{token}:{params}]", + token_format="<<{token}>>" + ) + # create streaming task + response = asyncio.create_task(stream(cb, "Hello")) + out_tokens = [] + # now stream + async for token in cb.aiter(): + out_tokens.append(token) + assert out_tokens == ["Hello", "[custom:END:]"] + + +# @pytest.mark.asyncio +# async def test_start_node(callback): +# """Test starting a node""" +# await callback.start_node("test_node") +# token = callback.queue.get_nowait() +# assert token == "" + +# @pytest.mark.asyncio +# async def test_end_node(callback): +# """Test ending a node""" +# await callback.start_node("test_node") +# await callback.end_node("test_node") +# # Get and discard the start token +# _ = callback.queue.get_nowait() +# token = callback.queue.get_nowait() +# assert token == "" + +# @pytest.mark.asyncio +# async def test_node_name_mismatch(callback): +# """Test node name mismatch error""" +# await callback.start_node("node1") +# with pytest.raises(ValueError, match="Node name mismatch"): +# callback("test token", node_name="node2") + +# @pytest.mark.asyncio +# async def test_token_streaming(callback): +# """Test basic token streaming""" +# await callback.start_node("test_node") +# test_tokens = ["Hello", " ", "World", "!"] + +# for token in test_tokens: +# await callback.acall(token, node_name="test_node") + +# # Skip the start node token +# _ = callback.queue.get_nowait() + +# # Check each streamed token +# for expected_token in test_tokens: +# token = callback.queue.get_nowait() +# assert token == expected_token + +# @pytest.mark.asyncio +# async def test_aiter_streaming(callback): +# """Test async iteration over tokens""" +# test_tokens = ["Hello", " ", "World", "!"] + +# await callback.start_node("test_node") +# for token in test_tokens: +# await callback.acall(token, node_name="test_node") +# await callback.end_node("test_node") +# await callback.close() + +# received_tokens = [] +# async for token in callback.aiter(): +# received_tokens.append(token) + +# assert len(received_tokens) == len(test_tokens) + 3 # +3 for start, end, and END tokens +# assert received_tokens[0] == "" +# assert received_tokens[-2] == "" +# assert received_tokens[-1] == "" +# assert received_tokens[1:-2] == test_tokens + +# @pytest.mark.asyncio +# async def test_inactive_node(callback): +# """Test behavior when node is inactive""" +# await callback.start_node("test_node", active=False) +# await callback.acall("This shouldn't be queued", node_name="test_node") +# await callback.end_node("test_node") + +# with pytest.raises(asyncio.QueueEmpty): +# callback.queue.get_nowait() + +# @pytest.mark.asyncio +# async def test_build_special_token(callback): +# """Test building special tokens with parameters""" +# token = await callback._build_special_token( +# "test", +# params={"key": "value", "number": 42} +# ) +# assert token == "" + +# # Test with no params +# token = await callback._build_special_token("test") +# assert token == "" + +# @pytest.mark.asyncio +# async def test_close(callback): +# """Test closing the callback""" +# await callback.close() +# token = callback.queue.get_nowait() +# assert token == "" + +# @pytest.mark.asyncio +# async def test_sequential_nodes(callback): +# """Test handling multiple sequential nodes""" +# # First node +# await callback.start_node("node1") +# await callback.acall("token1", node_name="node1") +# await callback.end_node("node1") + +# # Second node +# await callback.start_node("node2") +# await callback.acall("token2", node_name="node2") +# await callback.end_node("node2") + +# expected_sequence = [ +# "", +# "token1", +# "", +# "", +# "token2", +# "" +# ] + +# for expected in expected_sequence: +# token = callback.queue.get_nowait() +# assert token == expected From 7bf0620ddffed70b627135757315fd1efa89c35b Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 21 Dec 2024 21:02:24 +0400 Subject: [PATCH 2/3] fix: streaming all nodes and tokens --- graphai/callback.py | 12 +++++---- tests/unit/test_callback.py | 54 ++++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/graphai/callback.py b/graphai/callback.py index 1e9cb45..4a98805 100644 --- a/graphai/callback.py +++ b/graphai/callback.py @@ -142,6 +142,8 @@ async def aiter(self) -> AsyncIterator[str]: self._done = True # Mark as done after processing all tokens async def start_node(self, node_name: str, active: bool = True): + """Starts a new node and emits the start token. + """ if self._done: raise RuntimeError("Cannot start node on a closed stream") self.current_node_name = node_name @@ -150,7 +152,7 @@ async def start_node(self, node_name: str, active: bool = True): self.active = active if self.active: token = await self._build_special_token( - name="start", + name=f"{self.current_node_name}:start", params=None ) self.queue.put_nowait(token) @@ -162,17 +164,17 @@ async def start_node(self, node_name: str, active: bool = True): self.queue.put_nowait(node_token) async def end_node(self, node_name: str): + """Emits the end token for the current node. + """ if self._done: raise RuntimeError("Cannot end node on a closed stream") - self.current_node_name = node_name + #self.current_node_name = node_name if self.active: node_token = await self._build_special_token( - name=self.current_node_name, + name=f"{self.current_node_name}:end", params=None ) self.queue.put_nowait(node_token) - # TODO JB: should we use two tokens here or jump to END only? - await self.close() async def close(self): """Close the stream and prevent further tokens from being added. diff --git a/tests/unit/test_callback.py b/tests/unit/test_callback.py index 412e49a..6834d4e 100644 --- a/tests/unit/test_callback.py +++ b/tests/unit/test_callback.py @@ -19,14 +19,14 @@ async def node_start(input: str): @node(stream=True) async def node_a(input: str, callback: Callback): - tokens = ["Hello", " ", "World", "!"] + tokens = ["Hello", "World", "!"] for token in tokens: await callback.acall(token) return {"input": input} @node(stream=True) async def node_b(input: str, callback: Callback): - tokens = ["Here", " ", "is", " ", "node", " ", "B", "!"] + tokens = ["Here", "is", "node", "B", "!"] for token in tokens: await callback.acall(token) return {"input": input} @@ -38,7 +38,7 @@ async def node_c(input: str): @node(stream=True) async def node_d(input: str, callback: Callback): - tokens = ["Here", " ", "is", " ", "node", " ", "D", "!"] + tokens = ["Here", "is", "node", "D", "!"] for token in tokens: await callback.acall(token) return {"input": input} @@ -55,6 +55,8 @@ async def node_end(input: str): graph.add_node(node_fn) if i > 0: graph.add_edge(nodes[i-1], node_fn) + + graph.compile() return graph @@ -87,7 +89,7 @@ async def test_custom_initialization(self): assert cb.special_token_format == "[{identifier}:{token}:{params}]" assert cb.token_format == "<<{token}>>" # create streaming task - response = asyncio.create_task(stream(cb, "Hello")) + asyncio.create_task(stream(cb, "Hello")) out_tokens = [] # now stream async for token in cb.aiter(): @@ -99,7 +101,7 @@ async def test_default_tokens(self): """Test default tokens""" cb = Callback() # create streaming task - response = asyncio.create_task(stream(cb, "Hello")) + asyncio.create_task(stream(cb, "Hello")) out_tokens = [] # now stream async for token in cb.aiter(): @@ -115,13 +117,51 @@ async def test_custom_tokens(self): token_format="<<{token}>>" ) # create streaming task - response = asyncio.create_task(stream(cb, "Hello")) + asyncio.create_task(stream(cb, "Hello")) out_tokens = [] # now stream async for token in cb.aiter(): out_tokens.append(token) assert out_tokens == ["Hello", "[custom:END:]"] - + + +class TestCallbackGraph: + @pytest.mark.asyncio + async def test_callback_graph(self, define_graph): + """Test callback graph""" + graph = await define_graph + cb = graph.get_callback() + asyncio.create_task(graph.execute( + input={"input": "Hello"} + )) + out_tokens = [] + async for token in cb.aiter(): + out_tokens.append(token) + assert out_tokens == [ + "", + "", + "Hello", + "World", + "!", + "", + "", + "", + "Here", + "is", + "node", + "B", + "!", + "", + "", + "", + "Here", + "is", + "node", + "D", + "!", + "", + "" + ] # @pytest.mark.asyncio # async def test_start_node(callback): From 2c171cf705300faee4f63880672660b09baae7cd Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sun, 22 Dec 2024 18:44:20 +0400 Subject: [PATCH 3/3] feat: custom callback graph test --- tests/unit/test_callback.py | 40 +++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/unit/test_callback.py b/tests/unit/test_callback.py index 6834d4e..e27e156 100644 --- a/tests/unit/test_callback.py +++ b/tests/unit/test_callback.py @@ -163,6 +163,46 @@ async def test_callback_graph(self, define_graph): "" ] + @pytest.mark.asyncio + async def test_custom_callback_graph(self, define_graph): + """Test callback graph""" + graph = await define_graph + cb = graph.get_callback() + cb.identifier = "custom" + cb.special_token_format = "[{identifier}:{token}:{params}]" + cb.token_format = "<<{token}>>" + asyncio.create_task(graph.execute( + input={"input": "Hello"} + )) + out_tokens = [] + async for token in cb.aiter(): + out_tokens.append(token) + assert out_tokens == [ + "[custom:node_a:start:]", + "[custom:node_a:]", + "Hello", + "World", + "!", + "[custom:node_a:end:]", + "[custom:node_b:start:]", + "[custom:node_b:]", + "Here", + "is", + "node", + "B", + "!", + "[custom:node_b:end:]", + "[custom:node_d:start:]", + "[custom:node_d:]", + "Here", + "is", + "node", + "D", + "!", + "[custom:node_d:end:]", + "[custom:END:]" + ] + # @pytest.mark.asyncio # async def test_start_node(callback): # """Test starting a node"""