diff --git a/graphai/callback.py b/graphai/callback.py
index 4de897a..4a98805 100644
--- a/graphai/callback.py
+++ b/graphai/callback.py
@@ -1,4 +1,5 @@
 import asyncio
-from typing import Optional
+from pydantic import Field
+from typing import Any, Optional
 from collections.abc import AsyncIterator
 from semantic_router.utils.logger import logger
@@ -7,20 +8,113 @@ log_stream = True
 
 
 class Callback:
-    first_token = True
-    current_node_name: Optional[str] = None
-    active: bool = True
+    identifier: str = Field(
+        default="graphai",
+        description=(
+            "The identifier for special tokens. This allows us to easily "
+            "identify special tokens in the stream so we can handle them "
+            "correctly in any downstream process."
+        )
+    )
+    special_token_format: str = Field(
+        default="<{identifier}:{token}:{params}>",
+        description=(
+            "The format for special tokens. This is used to format special "
+            "tokens so they can be easily identified in the stream. "
+            "The format is a string with three possible components:\n"
+            "- {identifier}: An identifier shared by all special tokens, "
+            "by default this is 'graphai'.\n"
+            "- {token}: The special token type to be streamed. This may "
+            "be a tool name, identifier for start/end nodes, etc.\n"
+            "- {params}: Any additional parameters to be streamed. The parameters "
+            "are formatted as a comma-separated list of key-value pairs."
+        ),
+        examples=[
+            "<{identifier}:{token}:{params}>",
+            "<[{identifier} | {token} | {params}]>",
+            "<{token}:{params}>"
+        ]
+    )
+    token_format: str = Field(
+        default="{token}",
+        description=(
+            "The format for streamed tokens. This is used to format the "
+            "tokens typically returned from LLMs. By default, no special "
+            "formatting is applied."
+        )
+    )
+    _first_token: bool = Field(
+        default=True,
+        description="Whether this is the first token in the stream.",
+        exclude=True
+    )
+    _current_node_name: Optional[str] = Field(
+        default=None,
+        description="The name of the current node.",
+        exclude=True
+    )
+    _active: bool = Field(
+        default=True,
+        description="Whether the callback is active.",
+        exclude=True
+    )
+    _done: bool = Field(
+        default=False,
+        description="Whether the stream is done and should be closed.",
+        exclude=True
+    )
     queue: asyncio.Queue
 
-    def __init__(self):
+    def __init__(
+        self,
+        identifier: str = "graphai",
+        special_token_format: str = "<{identifier}:{token}:{params}>",
+        token_format: str = "{token}",
+    ):
+        self.identifier = identifier
+        self.special_token_format = special_token_format
+        self.token_format = token_format
         self.queue = asyncio.Queue()
+        self._done = False
+        self._first_token = True
+        self._current_node_name = None
+        self._active = True
+
+    @property
+    def first_token(self) -> bool:
+        return self._first_token
+
+    @first_token.setter
+    def first_token(self, value: bool):
+        self._first_token = value
+
+    @property
+    def current_node_name(self) -> Optional[str]:
+        return self._current_node_name
+
+    @current_node_name.setter
+    def current_node_name(self, value: Optional[str]):
+        self._current_node_name = value
+
+    @property
+    def active(self) -> bool:
+        return self._active
+
+    @active.setter
+    def active(self, value: bool):
+        self._active = value
 
     def __call__(self, token: str, node_name: Optional[str] = None):
+        if self._done:
+            raise RuntimeError("Cannot add tokens to a closed stream")
         self._check_node_name(node_name=node_name)
         # otherwise we just assume node is correct and send token
         self.queue.put_nowait(token)
 
     async def acall(self, token: str, node_name: Optional[str] = None):
+        # TODO JB: do we need to have `node_name` param?
+        if self._done:
+            raise RuntimeError("Cannot add tokens to a closed stream")
         self._check_node_name(node_name=node_name)
         # otherwise we just assume node is correct and send token
         self.queue.put_nowait(token)
@@ -30,29 +124,72 @@ async def aiter(self) -> AsyncIterator[str]:
         a generator that yields tokens from the queue until the END token
         is received.
         """
-        while True:
-            token = await self.queue.get()
-            yield token
-            self.queue.task_done()
-            if token == "":
+        end_token = await self._build_special_token(
+            name="END",
+            params=None
+        )
+        while True:  # Keep going until we see the END token
+            try:
+                if self._done and self.queue.empty():
+                    break
+                token = await self.queue.get()
+                yield token
+                self.queue.task_done()
+                if token == end_token:
+                    break
+            except asyncio.CancelledError:
                 break
+        self._done = True  # Mark as done after processing all tokens
 
     async def start_node(self, node_name: str, active: bool = True):
+        """Starts a new node and emits the start token.
+        """
+        if self._done:
+            raise RuntimeError("Cannot start node on a closed stream")
         self.current_node_name = node_name
         if self.first_token:
-            # TODO JB: not sure if we need self.first_token
             self.first_token = False
         self.active = active
         if self.active:
-            self.queue.put_nowait(f"")
+            token = await self._build_special_token(
+                name=f"{self.current_node_name}:start",
+                params=None
+            )
+            self.queue.put_nowait(token)
+            # TODO JB: should we use two tokens here?
+            node_token = await self._build_special_token(
+                name=self.current_node_name,
+                params=None
+            )
+            self.queue.put_nowait(node_token)
 
     async def end_node(self, node_name: str):
-        self.current_node_name = None
+        """Emits the end token for the current node.
+        """
+        if self._done:
+            raise RuntimeError("Cannot end node on a closed stream")
+        #self.current_node_name = node_name
         if self.active:
-            self.queue.put_nowait(f"")
+            node_token = await self._build_special_token(
+                name=f"{self.current_node_name}:end",
+                params=None
+            )
+            self.queue.put_nowait(node_token)
 
     async def close(self):
-        self.queue.put_nowait("")
+        """Close the stream and prevent further tokens from being added.
+        This will send an END token and set the done flag to True.
+ """ + if self._done: + return + end_token = await self._build_special_token( + name="END", + params=None + ) + self._done = True # Set done before putting the end token + self.queue.put_nowait(end_token) + # Don't wait for queue.join() as it can cause deadlock + # The stream will close when aiter processes the END token def _check_node_name(self, node_name: Optional[str] = None): if node_name: @@ -60,4 +197,19 @@ def _check_node_name(self, node_name: Optional[str] = None): if self.current_node_name != node_name: raise ValueError( f"Node name mismatch: {self.current_node_name} != {node_name}" - ) \ No newline at end of file + ) + + async def _build_special_token(self, name: str, params: dict[str, any] | None = None): + if params: + params_str = ",".join([f"{k}={v}" for k, v in params.items()]) + else: + params_str = "" + if self.identifier: + identifier = self.identifier + else: + identifier = "" + return self.special_token_format.format( + identifier=identifier, + token=name, + params=params_str + ) diff --git a/tests/unit/test_callback.py b/tests/unit/test_callback.py new file mode 100644 index 0000000..e27e156 --- /dev/null +++ b/tests/unit/test_callback.py @@ -0,0 +1,322 @@ +import pytest +import asyncio +from graphai.callback import Callback +from graphai import node, Graph +@pytest.fixture +async def callback(): + cb = Callback() + yield cb + await cb.close() + +@pytest.fixture +async def define_graph(): + """Define a graph with nodes that stream and don't stream. + """ + @node(start=True) + async def node_start(input: str): + # no stream added here + return {"input": input} + + @node(stream=True) + async def node_a(input: str, callback: Callback): + tokens = ["Hello", "World", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node(stream=True) + async def node_b(input: str, callback: Callback): + tokens = ["Here", "is", "node", "B", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node + async def node_c(input: str): + # no stream added here + return {"input": input} + + @node(stream=True) + async def node_d(input: str, callback: Callback): + tokens = ["Here", "is", "node", "D", "!"] + for token in tokens: + await callback.acall(token) + return {"input": input} + + @node(end=True) + async def node_end(input: str): + return {"input": input} + + graph = Graph() + + nodes = [node_start, node_a, node_b, node_c, node_d, node_end] + + for i, node_fn in enumerate(nodes): + graph.add_node(node_fn) + if i > 0: + graph.add_edge(nodes[i-1], node_fn) + + graph.compile() + + return graph + +async def stream(cb: Callback, text: str): + tokens = text.split(" ") + for token in tokens: + await cb.acall(token) + await cb.close() + return + +class TestCallbackConfig: + @pytest.mark.asyncio + async def test_callback_initialization(self): + """Test basic initialization of Callback class""" + cb = Callback() + assert cb.identifier == "graphai" + assert cb.special_token_format == "<{identifier}:{token}:{params}>" + assert cb.token_format == "{token}" + assert isinstance(cb.queue, asyncio.Queue) + + @pytest.mark.asyncio + async def test_custom_initialization(self): + """Test initialization with custom parameters""" + cb = Callback( + identifier="custom", + special_token_format="[{identifier}:{token}:{params}]", + token_format="<<{token}>>" + ) + assert cb.identifier == "custom" + assert cb.special_token_format == "[{identifier}:{token}:{params}]" + assert cb.token_format == "<<{token}>>" + # create streaming task + 
+        asyncio.create_task(stream(cb, "Hello"))
+        out_tokens = []
+        # now stream
+        async for token in cb.aiter():
+            out_tokens.append(token)
+        assert out_tokens == ["Hello", "[custom:END:]"]
+
+    @pytest.mark.asyncio
+    async def test_default_tokens(self):
+        """Test default tokens"""
+        cb = Callback()
+        # create streaming task
+        asyncio.create_task(stream(cb, "Hello"))
+        out_tokens = []
+        # now stream
+        async for token in cb.aiter():
+            out_tokens.append(token)
+        assert out_tokens == ["Hello", "<graphai:END:>"]
+
+    @pytest.mark.asyncio
+    async def test_custom_tokens(self):
+        """Test custom tokens"""
+        cb = Callback(
+            identifier="custom",
+            special_token_format="[{identifier}:{token}:{params}]",
+            token_format="<<{token}>>"
+        )
+        # create streaming task
+        asyncio.create_task(stream(cb, "Hello"))
+        out_tokens = []
+        # now stream
+        async for token in cb.aiter():
+            out_tokens.append(token)
+        assert out_tokens == ["Hello", "[custom:END:]"]
+
+
+class TestCallbackGraph:
+    @pytest.mark.asyncio
+    async def test_callback_graph(self, define_graph):
+        """Test callback graph"""
+        graph = await define_graph
+        cb = graph.get_callback()
+        asyncio.create_task(graph.execute(
+            input={"input": "Hello"}
+        ))
+        out_tokens = []
+        async for token in cb.aiter():
+            out_tokens.append(token)
+        assert out_tokens == [
+            "<graphai:node_a:start:>",
+            "<graphai:node_a:>",
+            "Hello",
+            "World",
+            "!",
+            "<graphai:node_a:end:>",
+            "<graphai:node_b:start:>",
+            "<graphai:node_b:>",
+            "Here",
+            "is",
+            "node",
+            "B",
+            "!",
+            "<graphai:node_b:end:>",
+            "<graphai:node_d:start:>",
+            "<graphai:node_d:>",
+            "Here",
+            "is",
+            "node",
+            "D",
+            "!",
+            "<graphai:node_d:end:>",
+            "<graphai:END:>"
+        ]
+
+    @pytest.mark.asyncio
+    async def test_custom_callback_graph(self, define_graph):
+        """Test callback graph"""
+        graph = await define_graph
+        cb = graph.get_callback()
+        cb.identifier = "custom"
+        cb.special_token_format = "[{identifier}:{token}:{params}]"
+        cb.token_format = "<<{token}>>"
+        asyncio.create_task(graph.execute(
+            input={"input": "Hello"}
+        ))
+        out_tokens = []
+        async for token in cb.aiter():
+            out_tokens.append(token)
+        assert out_tokens == [
+            "[custom:node_a:start:]",
+            "[custom:node_a:]",
+            "Hello",
+            "World",
+            "!",
+            "[custom:node_a:end:]",
+            "[custom:node_b:start:]",
+            "[custom:node_b:]",
+            "Here",
+            "is",
+            "node",
+            "B",
+            "!",
+            "[custom:node_b:end:]",
+            "[custom:node_d:start:]",
+            "[custom:node_d:]",
+            "Here",
+            "is",
+            "node",
+            "D",
+            "!",
+            "[custom:node_d:end:]",
+            "[custom:END:]"
+        ]
+
+# @pytest.mark.asyncio
+# async def test_start_node(callback):
+#     """Test starting a node"""
+#     await callback.start_node("test_node")
+#     token = callback.queue.get_nowait()
+#     assert token == "<graphai:test_node:start:>"
+
+# @pytest.mark.asyncio
+# async def test_end_node(callback):
+#     """Test ending a node"""
+#     await callback.start_node("test_node")
+#     await callback.end_node("test_node")
+#     # Get and discard the start token
+#     _ = callback.queue.get_nowait()
+#     token = callback.queue.get_nowait()
+#     assert token == "<graphai:test_node:end:>"
+
+# @pytest.mark.asyncio
+# async def test_node_name_mismatch(callback):
+#     """Test node name mismatch error"""
+#     await callback.start_node("node1")
+#     with pytest.raises(ValueError, match="Node name mismatch"):
+#         callback("test token", node_name="node2")
+
+# @pytest.mark.asyncio
+# async def test_token_streaming(callback):
+#     """Test basic token streaming"""
+#     await callback.start_node("test_node")
+#     test_tokens = ["Hello", " ", "World", "!"]
+
+#     for token in test_tokens:
+#         await callback.acall(token, node_name="test_node")
+
+#     # Skip the start node token
+#     _ = callback.queue.get_nowait()
+
+#     # Check each streamed token
+#     for expected_token in test_tokens:
+#         token = callback.queue.get_nowait()
+#         assert token == expected_token
+
+# @pytest.mark.asyncio
+# async def test_aiter_streaming(callback):
+#     """Test async iteration over tokens"""
+#     test_tokens = ["Hello", " ", "World", "!"]
+
+#     await callback.start_node("test_node")
+#     for token in test_tokens:
+#         await callback.acall(token, node_name="test_node")
+#     await callback.end_node("test_node")
+#     await callback.close()
+
+#     received_tokens = []
+#     async for token in callback.aiter():
+#         received_tokens.append(token)
+
+#     assert len(received_tokens) == len(test_tokens) + 3  # +3 for start, end, and END tokens
+#     assert received_tokens[0] == "<graphai:test_node:start:>"
+#     assert received_tokens[-2] == "<graphai:test_node:end:>"
+#     assert received_tokens[-1] == "<graphai:END:>"
+#     assert received_tokens[1:-2] == test_tokens
+
+# @pytest.mark.asyncio
+# async def test_inactive_node(callback):
+#     """Test behavior when node is inactive"""
+#     await callback.start_node("test_node", active=False)
+#     await callback.acall("This shouldn't be queued", node_name="test_node")
+#     await callback.end_node("test_node")
+
+#     with pytest.raises(asyncio.QueueEmpty):
+#         callback.queue.get_nowait()
+
+# @pytest.mark.asyncio
+# async def test_build_special_token(callback):
+#     """Test building special tokens with parameters"""
+#     token = await callback._build_special_token(
+#         "test",
+#         params={"key": "value", "number": 42}
+#     )
+#     assert token == "<graphai:test:key=value,number=42>"
+
+#     # Test with no params
+#     token = await callback._build_special_token("test")
+#     assert token == "<graphai:test:>"
+
+# @pytest.mark.asyncio
+# async def test_close(callback):
+#     """Test closing the callback"""
+#     await callback.close()
+#     token = callback.queue.get_nowait()
+#     assert token == "<graphai:END:>"
+
+# @pytest.mark.asyncio
+# async def test_sequential_nodes(callback):
+#     """Test handling multiple sequential nodes"""
+#     # First node
+#     await callback.start_node("node1")
+#     await callback.acall("token1", node_name="node1")
+#     await callback.end_node("node1")
+
+#     # Second node
+#     await callback.start_node("node2")
+#     await callback.acall("token2", node_name="node2")
+#     await callback.end_node("node2")
+
+#     expected_sequence = [
+#         "<graphai:node1:start:>",
+#         "token1",
+#         "<graphai:node1:end:>",
+#         "<graphai:node2:start:>",
+#         "token2",
+#         "<graphai:node2:end:>"
+#     ]
+
+#     for expected in expected_sequence:
+#         token = callback.queue.get_nowait()
+#         assert token == expected
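
For context, a minimal usage sketch of the customization this diff introduces. It exercises only the `Callback` API from `graphai/callback.py` above (the same custom formats as `test_custom_initialization`); the `produce`/`main` wiring is illustrative, not part of the library:

    import asyncio

    from graphai.callback import Callback


    async def main():
        # Custom identifier and special-token format, as in the tests above.
        cb = Callback(
            identifier="custom",
            special_token_format="[{identifier}:{token}:{params}]",
        )

        async def produce():
            # start_node emits two special tokens:
            # "[custom:node_a:start:]" and "[custom:node_a:]"
            await cb.start_node("node_a")
            await cb.acall("Hello")
            await cb.acall("World")
            # end_node emits "[custom:node_a:end:]"
            await cb.end_node("node_a")
            # close emits "[custom:END:]" and marks the stream as done
            await cb.close()

        producer = asyncio.create_task(produce())
        # aiter yields queued tokens and stops after yielding the END token
        async for token in cb.aiter():
            print(token)
        await producer


    asyncio.run(main())

Note that `close()` sets `_done` before enqueueing the END token, so `aiter` still drains any tokens already in the queue and exits once the END token comes through, which is why the close path deliberately avoids `queue.join()`.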