From 252cd01d1489ecf43c81ea8d1f44529b399f3d02 Mon Sep 17 00:00:00 2001 From: atanda rasheed Date: Sun, 5 Jan 2025 04:29:44 +0100 Subject: [PATCH] fix(display): handle multi inputs and outputs --- lib/axon/display.ex | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/lib/axon/display.ex b/lib/axon/display.ex index 11eb4f62..057ce4fb 100644 --- a/lib/axon/display.ex +++ b/lib/axon/display.ex @@ -109,7 +109,7 @@ defmodule Axon.Display do %Axon.Node{ id: id, op_name: :container, - parent: [parents], + parent: [_ | _] = parents, name: name_fn }, nodes, @@ -221,6 +221,13 @@ defmodule Axon.Display do "#{type}#{shape}" end + defp render_output_shape(shapes) when is_tuple(shapes) do + shapes + |> Tuple.to_list() + |> Enum.map(&render_output_shape(&1)) + |> Enum.join(", ") + end + defp type_str({type, size}), do: "#{Atom.to_string(type)}#{size}" defp render_options(opts) do @@ -347,7 +354,9 @@ defmodule Axon.Display do name = name_fn.(op, op_counts) node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates) - to_node = %{axon: :axon, id: id, op: op, name: name, shape: node_shape} + shape = expand_output_shape(node_shape) + + to_node = %{axon: :axon, id: id, op: op, name: name, shape: shape} new_edgelist = Enum.reduce(node_inputs, edgelist, fn from_node, acc -> @@ -357,6 +366,15 @@ defmodule Axon.Display do {to_node, {cache, op_counts, new_edgelist}} end + defp expand_output_shape(%Nx.Tensor{} = tensor), do: Nx.shape(tensor) + + defp expand_output_shape(shapes) when is_tuple(shapes) do + shapes + |> Tuple.to_list() + |> Enum.map(&expand_output_shape/1) + |> List.to_tuple() + end + defp generate_mermaid_node_entry(%{id: id, op: :input, name: name, shape: shape}) do ~s'#{id}[/"#{name} (:input) #{inspect(shape)}"/]' end