Skip to content

Commit

Permalink
store comfy node extraction status
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 committed Jan 8, 2025
1 parent 9993fdb commit a986841
Show file tree
Hide file tree
Showing 9 changed files with 464 additions and 291 deletions.
24 changes: 24 additions & 0 deletions entity/algolia.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package entity

import "registry-backend/ent"

type AlgoliaNode struct {
ObjectID string `json:"objectID"`
*ent.Node
LatestVersion *struct {
*ent.NodeVersion
ComfyNodes map[string]*ent.ComfyNode `json:"comfy_nodes"`
} `json:"latest_version"`
}

func (n *AlgoliaNode) ToEntNode() *ent.Node {
node := n.Node
if n.LatestVersion != nil {
nv := n.LatestVersion.NodeVersion
for _, v := range n.LatestVersion.ComfyNodes {
nv.Edges.ComfyNodes = append(nv.Edges.ComfyNodes, v)
}
node.Edges.Versions = []*ent.NodeVersion{nv}
}
return node
}
39 changes: 14 additions & 25 deletions gateways/algolia/algolia.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package algolia
import (
"context"
"fmt"
"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
"registry-backend/config" // assuming a config package exists to hold config values
"registry-backend/ent"
"registry-backend/entity"
"registry-backend/mapper"

"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
)

// AlgoliaService defines the interface for interacting with Algolia search.
Expand Down Expand Up @@ -47,27 +50,10 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) {
// IndexNodes indexes the provided nodes in Algolia.
func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
index := a.client.InitIndex("nodes_index")
objects := make([]map[string]interface{}, len(nodes))
objects := make([]entity.AlgoliaNode, len(nodes))

for i, n := range nodes {
o := map[string]interface{}{
"objectID": n.ID,
"name": n.Name,
"publisher_id": n.PublisherID,
"description": n.Description,
"id": n.ID,
"create_time": n.CreateTime,
"update_time": n.UpdateTime,
"license": n.License,
"repository_url": n.RepositoryURL,
"total_install": n.TotalInstall,
"status": n.Status,
"author": n.Author,
"category": n.Category,
"total_star": n.TotalStar,
"total_review": n.TotalReview,
}
objects[i] = o
objects[i] = mapper.AlgoliaNodeFromEntNode(n)
}

res, err := index.SaveObjects(objects)
Expand All @@ -79,18 +65,21 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
}

// SearchNodes searches for nodes in Algolia matching the query.
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) {
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) {
index := a.client.InitIndex("nodes_index")
res, err := index.Search(query, opts...)
if err != nil {
return nil, fmt.Errorf("failed to search nodes: %w", err)
}

var nodes []*ent.Node
if err := res.UnmarshalHits(&nodes); err != nil {
var algoliaNodes []entity.AlgoliaNode
if err := res.UnmarshalHits(&algoliaNodes); err != nil {
return nil, fmt.Errorf("failed to unmarshal search results: %w", err)
}
return nodes, nil
for _, n := range algoliaNodes {
nodes = append(nodes, n.ToEntNode())
}
return
}

// DeleteNode deletes the specified node from Algolia.
Expand Down
69 changes: 65 additions & 4 deletions gateways/algolia/algolia_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"registry-backend/config"
"registry-backend/ent"
"registry-backend/ent/schema"
"sort"
"testing"
"time"

Expand All @@ -32,12 +33,68 @@ func TestIndex(t *testing.T) {

t.Run("node", func(t *testing.T) {
ctx := context.Background()
id := uuid.New()
node := &ent.Node{
ID: uuid.NewString(),
Name: t.Name() + "-" + uuid.NewString(),
TotalStar: 98,
TotalReview: 20,
ID: id.String(),
CreateTime: time.Time{},
UpdateTime: time.Time{},
PublisherID: "id",
Name: t.Name() + "-" + uuid.NewString(),
Description: "desc",
Category: "cat",
Author: "au",
License: "license",
RepositoryURL: "somerepo",
IconURL: "someicon",
Tags: []string{"tags"},
TotalInstall: 10,
TotalStar: 98,
TotalReview: 20,
Status: "status",
StatusDetail: "status detail",
Edges: ent.NodeEdges{Versions: []*ent.NodeVersion{
{
ID: id,
NodeID: id.String(),
Version: "v1.0.0-" + uuid.NewString(),
Changelog: "test",
Status: schema.NodeVersionStatusActive,
StatusReason: "test",
PipDependencies: []string{"test"},
Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{
{
ID: "node1",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
{
ID: "node2",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
}},
},
}},
}

for i := 0; i < 10; i++ {
err = algolia.IndexNodes(ctx, node)
require.NoError(t, err)
Expand All @@ -47,6 +104,10 @@ func TestIndex(t *testing.T) {
nodes, err := algolia.SearchNodes(ctx, node.Name)
require.NoError(t, err)
require.Len(t, nodes, 1)
// sometimes the order is mixed and assertion fail
sort.Slice(nodes[0].Edges.Versions[0].Edges.ComfyNodes, func(i, j int) bool {
return nodes[0].Edges.Versions[0].Edges.ComfyNodes[i].ID < nodes[0].Edges.Versions[0].Edges.ComfyNodes[j].ID
})
assert.Equal(t, node, nodes[0])
})

Expand Down
Loading

0 comments on commit a986841

Please sign in to comment.