Skip to content

Commit

Permalink
Index Comfy Nodes From Latest Version to Algolia Search Index. (#122)
Browse files Browse the repository at this point in the history
Co-authored-by: James Kwon <[email protected]>
  • Loading branch information
james03160927 and james03160927 authored Jan 6, 2025
1 parent 76ee821 commit 5753550
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 103 deletions.
24 changes: 24 additions & 0 deletions entity/algolia.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package entity

import "registry-backend/ent"

type AlgoliaNode struct {
ObjectID string `json:"objectID"`
*ent.Node
LatestVersion *struct {
*ent.NodeVersion
ComfyNodes map[string]*ent.ComfyNode `json:"comfy_nodes"`
} `json:"latest_version"`
}

func (n *AlgoliaNode) ToEntNode() *ent.Node {
node := n.Node
if n.LatestVersion != nil {
nv := n.LatestVersion.NodeVersion
for _, v := range n.LatestVersion.ComfyNodes {
nv.Edges.ComfyNodes = append(nv.Edges.ComfyNodes, v)
}
node.Edges.Versions = []*ent.NodeVersion{nv}
}
return node
}
39 changes: 14 additions & 25 deletions gateways/algolia/algolia.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package algolia
import (
"context"
"fmt"
"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
"registry-backend/config" // assuming a config package exists to hold config values
"registry-backend/ent"
"registry-backend/entity"
"registry-backend/mapper"

"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
)

// AlgoliaService defines the interface for interacting with Algolia search.
Expand Down Expand Up @@ -47,27 +50,10 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) {
// IndexNodes indexes the provided nodes in Algolia.
func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
index := a.client.InitIndex("nodes_index")
objects := make([]map[string]interface{}, len(nodes))
objects := make([]entity.AlgoliaNode, len(nodes))

for i, n := range nodes {
o := map[string]interface{}{
"objectID": n.ID,
"name": n.Name,
"publisher_id": n.PublisherID,
"description": n.Description,
"id": n.ID,
"create_time": n.CreateTime,
"update_time": n.UpdateTime,
"license": n.License,
"repository_url": n.RepositoryURL,
"total_install": n.TotalInstall,
"status": n.Status,
"author": n.Author,
"category": n.Category,
"total_star": n.TotalStar,
"total_review": n.TotalReview,
}
objects[i] = o
objects[i] = mapper.AlgoliaNodeFromEntNode(n)
}

res, err := index.SaveObjects(objects)
Expand All @@ -79,18 +65,21 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
}

// SearchNodes searches for nodes in Algolia matching the query.
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) {
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) {
index := a.client.InitIndex("nodes_index")
res, err := index.Search(query, opts...)
if err != nil {
return nil, fmt.Errorf("failed to search nodes: %w", err)
}

var nodes []*ent.Node
if err := res.UnmarshalHits(&nodes); err != nil {
var algoliaNodes []entity.AlgoliaNode
if err := res.UnmarshalHits(&algoliaNodes); err != nil {
return nil, fmt.Errorf("failed to unmarshal search results: %w", err)
}
return nodes, nil
for _, n := range algoliaNodes {
nodes = append(nodes, n.ToEntNode())
}
return
}

// DeleteNode deletes the specified node from Algolia.
Expand Down
69 changes: 65 additions & 4 deletions gateways/algolia/algolia_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"registry-backend/config"
"registry-backend/ent"
"registry-backend/ent/schema"
"sort"
"testing"
"time"

Expand All @@ -32,12 +33,68 @@ func TestIndex(t *testing.T) {

t.Run("node", func(t *testing.T) {
ctx := context.Background()
id := uuid.New()
node := &ent.Node{
ID: uuid.NewString(),
Name: t.Name() + "-" + uuid.NewString(),
TotalStar: 98,
TotalReview: 20,
ID: id.String(),
CreateTime: time.Time{},
UpdateTime: time.Time{},
PublisherID: "id",
Name: t.Name() + "-" + uuid.NewString(),
Description: "desc",
Category: "cat",
Author: "au",
License: "license",
RepositoryURL: "somerepo",
IconURL: "someicon",
Tags: []string{"tags"},
TotalInstall: 10,
TotalStar: 98,
TotalReview: 20,
Status: "status",
StatusDetail: "status detail",
Edges: ent.NodeEdges{Versions: []*ent.NodeVersion{
{
ID: id,
NodeID: id.String(),
Version: "v1.0.0-" + uuid.NewString(),
Changelog: "test",
Status: schema.NodeVersionStatusActive,
StatusReason: "test",
PipDependencies: []string{"test"},
Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{
{
ID: "node1",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
{
ID: "node2",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
}},
},
}},
}

for i := 0; i < 10; i++ {
err = algolia.IndexNodes(ctx, node)
require.NoError(t, err)
Expand All @@ -47,6 +104,10 @@ func TestIndex(t *testing.T) {
nodes, err := algolia.SearchNodes(ctx, node.Name)
require.NoError(t, err)
require.Len(t, nodes, 1)
// sometimes the order is mixed and assertion fail
sort.Slice(nodes[0].Edges.Versions[0].Edges.ComfyNodes, func(i, j int) bool {
return nodes[0].Edges.Versions[0].Edges.ComfyNodes[i].ID < nodes[0].Edges.Versions[0].Edges.ComfyNodes[j].ID
})
assert.Equal(t, node, nodes[0])
})

Expand Down
48 changes: 32 additions & 16 deletions integration-tests/registry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"registry-backend/ent/nodeversion"
"registry-backend/ent/schema"
drip_logging "registry-backend/logging"
"registry-backend/mapper"
"registry-backend/mock/gateways"
"registry-backend/server/implementation"
drip_authorization "registry-backend/server/middleware/authorization"
Expand Down Expand Up @@ -882,31 +883,21 @@ func TestRegistryComfyNode(t *testing.T) {

// create node version
node := randomNode()
nodeVersion := randomNodeVersion(0)
signedUrl := "test-url"
impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return(signedUrl, nil)
impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return(signedUrl, nil)
impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil)
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Body: &drip.PublishNodeVersionJSONRequestBody{
PersonalAccessToken: token,
Node: *node,
NodeVersion: *nodeVersion,
},
})
require.NoError(t, err, "should not return error")

// create another node versions
nodeVersionToBeBackfill := []*drip.NodeVersion{
// create node versions
nodeVersions := []*drip.NodeVersion{
randomNodeVersion(0),
randomNodeVersion(1),
randomNodeVersion(2),
randomNodeVersion(3),
randomNodeVersion(4),
randomNodeVersion(5),
}
for _, nv := range nodeVersionToBeBackfill {
for _, nv := range nodeVersions {
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Expand All @@ -918,6 +909,8 @@ func TestRegistryComfyNode(t *testing.T) {
})
require.NoError(t, err, "should not return error")
}
nodeVersion := nodeVersions[len(nodeVersions)-1]
backfilledNodeVersions := nodeVersions[:len(nodeVersions)-1]

t.Run("NoComfyNode", func(t *testing.T) {
res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{
Expand Down Expand Up @@ -965,6 +958,29 @@ func TestRegistryComfyNode(t *testing.T) {
require.NoError(t, err)
require.IsType(t, drip.CreateComfyNodes204Response{}, res)

t.Run("AssertAlgolia", func(t *testing.T) {
indexed := impl.mockAlgolia.LastIndexedNodes
require.Len(t, impl.mockAlgolia.LastIndexedNodes, 1)

node, err := client.Node.Get(ctx, *node.Id)
require.NoError(t, err)
nodeVersion, err := client.NodeVersion.Query().Where(nodeversion.Version(*nodeVersion.Version)).WithComfyNodes().Only(ctx)
require.NoError(t, err)
node.Edges.Versions = append(node.Edges.Versions, nodeVersion)

assert.Equal(t, node.ID, indexed[0].ID)
assert.Equal(t, node.Edges.Versions[0].ID, indexed[0].Edges.Versions[0].ID)
indexedComfyNodes := drip.CreateComfyNodesJSONRequestBody{
Nodes: &map[string]drip.ComfyNode{},
}
for _, node := range indexed[0].Edges.Versions[0].Edges.ComfyNodes {
cn := *(mapper.DBComfyNodeToApiComfyNode(node))
cn.ComfyNodeId = nil
(*indexedComfyNodes.Nodes)[node.ID] = cn
}
assert.Equal(t, comfyNodes, indexedComfyNodes)
})

t.Run("GetComfyNodes", func(t *testing.T) {
for k, v := range *comfyNodes.Nodes {
v.ComfyNodeId = proto.String(k)
Expand Down Expand Up @@ -1055,8 +1071,8 @@ func TestRegistryComfyNode(t *testing.T) {
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill)+mockCalled)
mockCalled += len(nodeVersionToBeBackfill)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(backfilledNodeVersions)+mockCalled)
mockCalled += len(backfilledNodeVersions)
})

t.Run("Limited", func(t *testing.T) {
Expand Down
42 changes: 42 additions & 0 deletions mapper/algolia.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package mapper

import (
"registry-backend/ent"
"registry-backend/entity"
)

func AlgoliaNodeFromEntNode(node *ent.Node) entity.AlgoliaNode {
n := entity.AlgoliaNode{
ObjectID: node.ID,
Node: new(ent.Node),
}
*n.Node = *node
n.Edges = ent.NodeEdges{}
if node.Edges.Versions == nil {
return n
}

var lv *ent.NodeVersion
for _, v := range node.Edges.Versions {
if lv == nil {
lv = v
} else if v.CreateTime.After(lv.CreateTime) {
lv = v
}
}

n.LatestVersion = &struct {
*ent.NodeVersion
ComfyNodes map[string]*ent.ComfyNode `json:"comfy_nodes"`
}{
NodeVersion: new(ent.NodeVersion),
ComfyNodes: make(map[string]*ent.ComfyNode, len(lv.Edges.ComfyNodes)),
}
*n.LatestVersion.NodeVersion = *lv
n.LatestVersion.NodeVersion.Edges = ent.NodeVersionEdges{}
for _, v := range lv.Edges.ComfyNodes {
n.LatestVersion.ComfyNodes[v.ID] = v
}

return n
}
3 changes: 3 additions & 0 deletions mock/gateways/mock_algolia_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ var _ algolia.AlgoliaService = &MockAlgoliaService{}

type MockAlgoliaService struct {
mock.Mock

LastIndexedNodes []*ent.Node
}

// IndexNodes implements algolia.AlgoliaService.
func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error {
m.LastIndexedNodes = n
args := m.Called(ctx, n)
return args.Error(0)
}
Expand Down
Loading

0 comments on commit 5753550

Please sign in to comment.