Skip to content

Commit

Permalink
Index Comfy Nodes From Latest Version to Algolia Search Index. (#129)
Browse files Browse the repository at this point in the history
* Index comfy nodes from the latest version to Algolia Search Index.

* specify algolia field

---------

Co-authored-by: James Kwon <[email protected]>
  • Loading branch information
james03160927 and james03160927 authored Jan 8, 2025
1 parent 658065f commit 61cbf41
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 103 deletions.
79 changes: 79 additions & 0 deletions entity/algolia.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package entity

import (
"registry-backend/ent"
"registry-backend/ent/schema"
"time"
)

type AlgoliaNode struct {
ObjectID string `json:"objectID"`

ID string `json:"id,omitempty"`
CreateTime time.Time `json:"create_time,omitempty"`
UpdateTime time.Time `json:"update_time,omitempty"`
PublisherID string `json:"publisher_id,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Category string `json:"category,omitempty"`
Author string `json:"author,omitempty"`
License string `json:"license,omitempty"`
RepositoryURL string `json:"repository_url,omitempty"`
IconURL string `json:"icon_url,omitempty"`
Tags []string `json:"tags,omitempty"`
TotalInstall int64 `json:"total_install,omitempty"`
TotalStar int64 `json:"total_star,omitempty"`
TotalReview int64 `json:"total_review,omitempty"`
Status schema.NodeStatus `json:"status,omitempty"`
StatusDetail string `json:"status_detail,omitempty"`

LatestVersion string `json:"latest_version,omitempty"`
LatestVersionStatus schema.NodeVersionStatus `json:"latest_version_status,omitempty"`
ComfyNodeNames []string `json:"comfy_nodes,omitempty"`
}

func (n *AlgoliaNode) ToEntNode() *ent.Node {
node := &ent.Node{
ID: n.ID,
CreateTime: n.CreateTime,
UpdateTime: n.UpdateTime,
PublisherID: n.PublisherID,
Name: n.Name,
Description: n.Description,
Category: n.Category,
Author: n.Author,
License: n.License,
RepositoryURL: n.RepositoryURL,
IconURL: n.IconURL,
Tags: n.Tags,
TotalInstall: n.TotalInstall,
TotalStar: n.TotalStar,
TotalReview: n.TotalReview,
Status: n.Status,
StatusDetail: n.StatusDetail,
}
if n.LatestVersion == "" {
return node
}

node.Edges = ent.NodeEdges{
Versions: []*ent.NodeVersion{{
NodeID: n.ID,
Version: n.LatestVersion,
Status: n.LatestVersionStatus,
}},
}
if len(n.ComfyNodeNames) == 0 {
return node
}

node.Edges.Versions[0].Edges = ent.NodeVersionEdges{
ComfyNodes: make([]*ent.ComfyNode, 0, len(n.ComfyNodeNames)),
}
for _, name := range n.ComfyNodeNames {
node.Edges.Versions[0].Edges.ComfyNodes = append(node.Edges.Versions[0].Edges.ComfyNodes, &ent.ComfyNode{
ID: name,
})
}
return node
}
39 changes: 14 additions & 25 deletions gateways/algolia/algolia.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package algolia
import (
"context"
"fmt"
"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
"registry-backend/config" // assuming a config package exists to hold config values
"registry-backend/ent"
"registry-backend/entity"
"registry-backend/mapper"

"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
)

// AlgoliaService defines the interface for interacting with Algolia search.
Expand Down Expand Up @@ -47,27 +50,10 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) {
// IndexNodes indexes the provided nodes in Algolia.
func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
index := a.client.InitIndex("nodes_index")
objects := make([]map[string]interface{}, len(nodes))
objects := make([]entity.AlgoliaNode, len(nodes))

for i, n := range nodes {
o := map[string]interface{}{
"objectID": n.ID,
"name": n.Name,
"publisher_id": n.PublisherID,
"description": n.Description,
"id": n.ID,
"create_time": n.CreateTime,
"update_time": n.UpdateTime,
"license": n.License,
"repository_url": n.RepositoryURL,
"total_install": n.TotalInstall,
"status": n.Status,
"author": n.Author,
"category": n.Category,
"total_star": n.TotalStar,
"total_review": n.TotalReview,
}
objects[i] = o
objects[i] = mapper.AlgoliaNodeFromEntNode(n)
}

res, err := index.SaveObjects(objects)
Expand All @@ -79,18 +65,21 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
}

// SearchNodes searches for nodes in Algolia matching the query.
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) {
func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) {
index := a.client.InitIndex("nodes_index")
res, err := index.Search(query, opts...)
if err != nil {
return nil, fmt.Errorf("failed to search nodes: %w", err)
}

var nodes []*ent.Node
if err := res.UnmarshalHits(&nodes); err != nil {
var algoliaNodes []entity.AlgoliaNode
if err := res.UnmarshalHits(&algoliaNodes); err != nil {
return nil, fmt.Errorf("failed to unmarshal search results: %w", err)
}
return nodes, nil
for _, n := range algoliaNodes {
nodes = append(nodes, n.ToEntNode())
}
return
}

// DeleteNode deletes the specified node from Algolia.
Expand Down
79 changes: 75 additions & 4 deletions gateways/algolia/algolia_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,69 @@ func TestIndex(t *testing.T) {

t.Run("node", func(t *testing.T) {
ctx := context.Background()
id := uuid.New()
version := "v1.0.0-" + uuid.NewString()
node := &ent.Node{
ID: uuid.NewString(),
Name: t.Name() + "-" + uuid.NewString(),
TotalStar: 98,
TotalReview: 20,
ID: id.String(),
CreateTime: time.Time{},
UpdateTime: time.Time{},
PublisherID: "id",
Name: t.Name() + "-" + uuid.NewString(),
Description: "desc",
Category: "cat",
Author: "au",
License: "license",
RepositoryURL: "somerepo",
IconURL: "someicon",
Tags: []string{"tags"},
TotalInstall: 10,
TotalStar: 98,
TotalReview: 20,
Status: "status",
StatusDetail: "status detail",
Edges: ent.NodeEdges{Versions: []*ent.NodeVersion{
{
ID: id,
NodeID: id.String(),
Version: version,
Changelog: "test",
Status: schema.NodeVersionStatusActive,
StatusReason: "test",
PipDependencies: []string{"test"},
Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{
{
ID: "node1",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
{
ID: "node2",
NodeVersionID: id,
Category: "test",
Function: "test",
Description: "test",
Deprecated: false,
Experimental: false,
InputTypes: "test",
OutputIsList: []bool{true},
ReturnNames: []string{"test"},
ReturnTypes: "test",
CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC),
},
}},
},
}},
}

for i := 0; i < 10; i++ {
err = algolia.IndexNodes(ctx, node)
require.NoError(t, err)
Expand All @@ -47,6 +104,20 @@ func TestIndex(t *testing.T) {
nodes, err := algolia.SearchNodes(ctx, node.Name)
require.NoError(t, err)
require.Len(t, nodes, 1)
// partial information
node.Edges = ent.NodeEdges{
Versions: []*ent.NodeVersion{
{
NodeID: id.String(),
Version: version,
Status: schema.NodeVersionStatusActive,
Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{
{ID: "node1"},
{ID: "node2"},
}},
},
},
}
assert.Equal(t, node, nodes[0])
})

Expand Down
48 changes: 32 additions & 16 deletions integration-tests/registry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"registry-backend/ent/nodeversion"
"registry-backend/ent/schema"
drip_logging "registry-backend/logging"
"registry-backend/mapper"
"registry-backend/mock/gateways"
"registry-backend/server/implementation"
drip_authorization "registry-backend/server/middleware/authorization"
Expand Down Expand Up @@ -810,31 +811,21 @@ func TestRegistryComfyNode(t *testing.T) {

// create node version
node := randomNode()
nodeVersion := randomNodeVersion(0)
signedUrl := "test-url"
impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return(signedUrl, nil)
impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return(signedUrl, nil)
impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil)
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Body: &drip.PublishNodeVersionJSONRequestBody{
PersonalAccessToken: token,
Node: *node,
NodeVersion: *nodeVersion,
},
})
require.NoError(t, err, "should not return error")

// create another node versions
nodeVersionToBeBackfill := []*drip.NodeVersion{
// create node versions
nodeVersions := []*drip.NodeVersion{
randomNodeVersion(0),
randomNodeVersion(1),
randomNodeVersion(2),
randomNodeVersion(3),
randomNodeVersion(4),
randomNodeVersion(5),
}
for _, nv := range nodeVersionToBeBackfill {
for _, nv := range nodeVersions {
_, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Expand All @@ -846,6 +837,8 @@ func TestRegistryComfyNode(t *testing.T) {
})
require.NoError(t, err, "should not return error")
}
nodeVersion := nodeVersions[len(nodeVersions)-1]
backfilledNodeVersions := nodeVersions[:len(nodeVersions)-1]

t.Run("NoComfyNode", func(t *testing.T) {
res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{
Expand Down Expand Up @@ -893,6 +886,29 @@ func TestRegistryComfyNode(t *testing.T) {
require.NoError(t, err)
require.IsType(t, drip.CreateComfyNodes204Response{}, res)

t.Run("AssertAlgolia", func(t *testing.T) {
indexed := impl.mockAlgolia.LastIndexedNodes
require.Len(t, impl.mockAlgolia.LastIndexedNodes, 1)

node, err := client.Node.Get(ctx, *node.Id)
require.NoError(t, err)
nodeVersion, err := client.NodeVersion.Query().Where(nodeversion.Version(*nodeVersion.Version)).WithComfyNodes().Only(ctx)
require.NoError(t, err)
node.Edges.Versions = append(node.Edges.Versions, nodeVersion)

assert.Equal(t, node.ID, indexed[0].ID)
assert.Equal(t, node.Edges.Versions[0].ID, indexed[0].Edges.Versions[0].ID)
indexedComfyNodes := drip.CreateComfyNodesJSONRequestBody{
Nodes: &map[string]drip.ComfyNode{},
}
for _, node := range indexed[0].Edges.Versions[0].Edges.ComfyNodes {
cn := *(mapper.DBComfyNodeToApiComfyNode(node))
cn.ComfyNodeId = nil
(*indexedComfyNodes.Nodes)[node.ID] = cn
}
assert.Equal(t, comfyNodes, indexedComfyNodes)
})

t.Run("GetComfyNodes", func(t *testing.T) {
for k, v := range *comfyNodes.Nodes {
v.ComfyNodeId = proto.String(k)
Expand Down Expand Up @@ -983,8 +999,8 @@ func TestRegistryComfyNode(t *testing.T) {
res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{})
require.NoError(t, err, "should return created node version")
require.IsType(t, drip.ComfyNodesBackfill204Response{}, res)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill)+mockCalled)
mockCalled += len(nodeVersionToBeBackfill)
impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(backfilledNodeVersions)+mockCalled)
mockCalled += len(backfilledNodeVersions)
})

t.Run("Limited", func(t *testing.T) {
Expand Down
53 changes: 53 additions & 0 deletions mapper/algolia.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package mapper

import (
"registry-backend/ent"
"registry-backend/entity"
)

func AlgoliaNodeFromEntNode(node *ent.Node) entity.AlgoliaNode {
n := entity.AlgoliaNode{
ObjectID: node.ID,
ID: node.ID,
CreateTime: node.CreateTime,
UpdateTime: node.UpdateTime,
PublisherID: node.PublisherID,
Name: node.Name,
Description: node.Description,
Category: node.Category,
Author: node.Author,
License: node.License,
RepositoryURL: node.RepositoryURL,
IconURL: node.IconURL,
Tags: node.Tags,
TotalInstall: node.TotalInstall,
TotalStar: node.TotalStar,
TotalReview: node.TotalReview,
Status: node.Status,
StatusDetail: node.StatusDetail,

LatestVersion: "",
LatestVersionStatus: "",
}

if node.Edges.Versions == nil {
return n
}

var lv *ent.NodeVersion
for _, v := range node.Edges.Versions {
if lv == nil {
lv = v
} else if v.CreateTime.After(lv.CreateTime) {
lv = v
}
}
n.LatestVersion = lv.Version
n.LatestVersionStatus = lv.Status
n.ComfyNodeNames = make([]string, 0, len(lv.Edges.ComfyNodes))
for _, v := range lv.Edges.ComfyNodes {
n.ComfyNodeNames = append(n.ComfyNodeNames, v.ID)
}

return n
}
Loading

0 comments on commit 61cbf41

Please sign in to comment.