From 61cbf41ef8e6e186ea6b76f5dd9ef3a05cacd146 Mon Sep 17 00:00:00 2001 From: James Kwon <96548424+james03160927@users.noreply.github.com> Date: Tue, 7 Jan 2025 22:30:07 -0500 Subject: [PATCH] Index Comfy Nodes From Latest Version to Algolia Search Index. (#129) * Index comfy nodes from the latest version to Algolia Search Index. * specify algolia field --------- Co-authored-by: James Kwon <96548424+hongil0316@users.noreply.github.com> --- entity/algolia.go | 79 ++++++++++ gateways/algolia/algolia.go | 39 ++--- gateways/algolia/algolia_test.go | 79 +++++++++- .../registry_integration_test.go | 48 ++++-- mapper/algolia.go | 53 +++++++ mock/gateways/mock_algolia_service.go | 3 + services/registry/registry_svc.go | 141 +++++++++++------- 7 files changed, 339 insertions(+), 103 deletions(-) create mode 100644 entity/algolia.go create mode 100644 mapper/algolia.go diff --git a/entity/algolia.go b/entity/algolia.go new file mode 100644 index 0000000..b570847 --- /dev/null +++ b/entity/algolia.go @@ -0,0 +1,79 @@ +package entity + +import ( + "registry-backend/ent" + "registry-backend/ent/schema" + "time" +) + +type AlgoliaNode struct { + ObjectID string `json:"objectID"` + + ID string `json:"id,omitempty"` + CreateTime time.Time `json:"create_time,omitempty"` + UpdateTime time.Time `json:"update_time,omitempty"` + PublisherID string `json:"publisher_id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Category string `json:"category,omitempty"` + Author string `json:"author,omitempty"` + License string `json:"license,omitempty"` + RepositoryURL string `json:"repository_url,omitempty"` + IconURL string `json:"icon_url,omitempty"` + Tags []string `json:"tags,omitempty"` + TotalInstall int64 `json:"total_install,omitempty"` + TotalStar int64 `json:"total_star,omitempty"` + TotalReview int64 `json:"total_review,omitempty"` + Status schema.NodeStatus `json:"status,omitempty"` + StatusDetail string `json:"status_detail,omitempty"` + + LatestVersion string `json:"latest_version,omitempty"` + LatestVersionStatus schema.NodeVersionStatus `json:"latest_version_status,omitempty"` + ComfyNodeNames []string `json:"comfy_nodes,omitempty"` +} + +func (n *AlgoliaNode) ToEntNode() *ent.Node { + node := &ent.Node{ + ID: n.ID, + CreateTime: n.CreateTime, + UpdateTime: n.UpdateTime, + PublisherID: n.PublisherID, + Name: n.Name, + Description: n.Description, + Category: n.Category, + Author: n.Author, + License: n.License, + RepositoryURL: n.RepositoryURL, + IconURL: n.IconURL, + Tags: n.Tags, + TotalInstall: n.TotalInstall, + TotalStar: n.TotalStar, + TotalReview: n.TotalReview, + Status: n.Status, + StatusDetail: n.StatusDetail, + } + if n.LatestVersion == "" { + return node + } + + node.Edges = ent.NodeEdges{ + Versions: []*ent.NodeVersion{{ + NodeID: n.ID, + Version: n.LatestVersion, + Status: n.LatestVersionStatus, + }}, + } + if len(n.ComfyNodeNames) == 0 { + return node + } + + node.Edges.Versions[0].Edges = ent.NodeVersionEdges{ + ComfyNodes: make([]*ent.ComfyNode, 0, len(n.ComfyNodeNames)), + } + for _, name := range n.ComfyNodeNames { + node.Edges.Versions[0].Edges.ComfyNodes = append(node.Edges.Versions[0].Edges.ComfyNodes, &ent.ComfyNode{ + ID: name, + }) + } + return node +} diff --git a/gateways/algolia/algolia.go b/gateways/algolia/algolia.go index 28597c7..3c1e48a 100644 --- a/gateways/algolia/algolia.go +++ b/gateways/algolia/algolia.go @@ -3,10 +3,13 @@ package algolia import ( "context" "fmt" - "github.com/algolia/algoliasearch-client-go/v3/algolia/search" - "github.com/rs/zerolog/log" "registry-backend/config" // assuming a config package exists to hold config values "registry-backend/ent" + "registry-backend/entity" + "registry-backend/mapper" + + "github.com/algolia/algoliasearch-client-go/v3/algolia/search" + "github.com/rs/zerolog/log" ) // AlgoliaService defines the interface for interacting with Algolia search. @@ -47,27 +50,10 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) { // IndexNodes indexes the provided nodes in Algolia. func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { index := a.client.InitIndex("nodes_index") - objects := make([]map[string]interface{}, len(nodes)) + objects := make([]entity.AlgoliaNode, len(nodes)) for i, n := range nodes { - o := map[string]interface{}{ - "objectID": n.ID, - "name": n.Name, - "publisher_id": n.PublisherID, - "description": n.Description, - "id": n.ID, - "create_time": n.CreateTime, - "update_time": n.UpdateTime, - "license": n.License, - "repository_url": n.RepositoryURL, - "total_install": n.TotalInstall, - "status": n.Status, - "author": n.Author, - "category": n.Category, - "total_star": n.TotalStar, - "total_review": n.TotalReview, - } - objects[i] = o + objects[i] = mapper.AlgoliaNodeFromEntNode(n) } res, err := index.SaveObjects(objects) @@ -79,18 +65,21 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { } // SearchNodes searches for nodes in Algolia matching the query. -func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { +func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) { index := a.client.InitIndex("nodes_index") res, err := index.Search(query, opts...) if err != nil { return nil, fmt.Errorf("failed to search nodes: %w", err) } - var nodes []*ent.Node - if err := res.UnmarshalHits(&nodes); err != nil { + var algoliaNodes []entity.AlgoliaNode + if err := res.UnmarshalHits(&algoliaNodes); err != nil { return nil, fmt.Errorf("failed to unmarshal search results: %w", err) } - return nodes, nil + for _, n := range algoliaNodes { + nodes = append(nodes, n.ToEntNode()) + } + return } // DeleteNode deletes the specified node from Algolia. diff --git a/gateways/algolia/algolia_test.go b/gateways/algolia/algolia_test.go index 1c09723..28840b7 100644 --- a/gateways/algolia/algolia_test.go +++ b/gateways/algolia/algolia_test.go @@ -32,12 +32,69 @@ func TestIndex(t *testing.T) { t.Run("node", func(t *testing.T) { ctx := context.Background() + id := uuid.New() + version := "v1.0.0-" + uuid.NewString() node := &ent.Node{ - ID: uuid.NewString(), - Name: t.Name() + "-" + uuid.NewString(), - TotalStar: 98, - TotalReview: 20, + ID: id.String(), + CreateTime: time.Time{}, + UpdateTime: time.Time{}, + PublisherID: "id", + Name: t.Name() + "-" + uuid.NewString(), + Description: "desc", + Category: "cat", + Author: "au", + License: "license", + RepositoryURL: "somerepo", + IconURL: "someicon", + Tags: []string{"tags"}, + TotalInstall: 10, + TotalStar: 98, + TotalReview: 20, + Status: "status", + StatusDetail: "status detail", + Edges: ent.NodeEdges{Versions: []*ent.NodeVersion{ + { + ID: id, + NodeID: id.String(), + Version: version, + Changelog: "test", + Status: schema.NodeVersionStatusActive, + StatusReason: "test", + PipDependencies: []string{"test"}, + Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{ + { + ID: "node1", + NodeVersionID: id, + Category: "test", + Function: "test", + Description: "test", + Deprecated: false, + Experimental: false, + InputTypes: "test", + OutputIsList: []bool{true}, + ReturnNames: []string{"test"}, + ReturnTypes: "test", + CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + ID: "node2", + NodeVersionID: id, + Category: "test", + Function: "test", + Description: "test", + Deprecated: false, + Experimental: false, + InputTypes: "test", + OutputIsList: []bool{true}, + ReturnNames: []string{"test"}, + ReturnTypes: "test", + CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }}, + }, + }}, } + for i := 0; i < 10; i++ { err = algolia.IndexNodes(ctx, node) require.NoError(t, err) @@ -47,6 +104,20 @@ func TestIndex(t *testing.T) { nodes, err := algolia.SearchNodes(ctx, node.Name) require.NoError(t, err) require.Len(t, nodes, 1) + // partial information + node.Edges = ent.NodeEdges{ + Versions: []*ent.NodeVersion{ + { + NodeID: id.String(), + Version: version, + Status: schema.NodeVersionStatusActive, + Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{ + {ID: "node1"}, + {ID: "node2"}, + }}, + }, + }, + } assert.Equal(t, node, nodes[0]) }) diff --git a/integration-tests/registry_integration_test.go b/integration-tests/registry_integration_test.go index fdb70a7..860c458 100644 --- a/integration-tests/registry_integration_test.go +++ b/integration-tests/registry_integration_test.go @@ -12,6 +12,7 @@ import ( "registry-backend/ent/nodeversion" "registry-backend/ent/schema" drip_logging "registry-backend/logging" + "registry-backend/mapper" "registry-backend/mock/gateways" "registry-backend/server/implementation" drip_authorization "registry-backend/server/middleware/authorization" @@ -810,31 +811,21 @@ func TestRegistryComfyNode(t *testing.T) { // create node version node := randomNode() - nodeVersion := randomNodeVersion(0) signedUrl := "test-url" impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return(signedUrl, nil) impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return(signedUrl, nil) impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil) - _, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: *pub.Id, - NodeId: *node.Id, - Body: &drip.PublishNodeVersionJSONRequestBody{ - PersonalAccessToken: token, - Node: *node, - NodeVersion: *nodeVersion, - }, - }) - require.NoError(t, err, "should not return error") - // create another node versions - nodeVersionToBeBackfill := []*drip.NodeVersion{ + // create node versions + nodeVersions := []*drip.NodeVersion{ + randomNodeVersion(0), randomNodeVersion(1), randomNodeVersion(2), randomNodeVersion(3), randomNodeVersion(4), randomNodeVersion(5), } - for _, nv := range nodeVersionToBeBackfill { + for _, nv := range nodeVersions { _, err = withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ PublisherId: *pub.Id, NodeId: *node.Id, @@ -846,6 +837,8 @@ func TestRegistryComfyNode(t *testing.T) { }) require.NoError(t, err, "should not return error") } + nodeVersion := nodeVersions[len(nodeVersions)-1] + backfilledNodeVersions := nodeVersions[:len(nodeVersions)-1] t.Run("NoComfyNode", func(t *testing.T) { res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{ @@ -893,6 +886,29 @@ func TestRegistryComfyNode(t *testing.T) { require.NoError(t, err) require.IsType(t, drip.CreateComfyNodes204Response{}, res) + t.Run("AssertAlgolia", func(t *testing.T) { + indexed := impl.mockAlgolia.LastIndexedNodes + require.Len(t, impl.mockAlgolia.LastIndexedNodes, 1) + + node, err := client.Node.Get(ctx, *node.Id) + require.NoError(t, err) + nodeVersion, err := client.NodeVersion.Query().Where(nodeversion.Version(*nodeVersion.Version)).WithComfyNodes().Only(ctx) + require.NoError(t, err) + node.Edges.Versions = append(node.Edges.Versions, nodeVersion) + + assert.Equal(t, node.ID, indexed[0].ID) + assert.Equal(t, node.Edges.Versions[0].ID, indexed[0].Edges.Versions[0].ID) + indexedComfyNodes := drip.CreateComfyNodesJSONRequestBody{ + Nodes: &map[string]drip.ComfyNode{}, + } + for _, node := range indexed[0].Edges.Versions[0].Edges.ComfyNodes { + cn := *(mapper.DBComfyNodeToApiComfyNode(node)) + cn.ComfyNodeId = nil + (*indexedComfyNodes.Nodes)[node.ID] = cn + } + assert.Equal(t, comfyNodes, indexedComfyNodes) + }) + t.Run("GetComfyNodes", func(t *testing.T) { for k, v := range *comfyNodes.Nodes { v.ComfyNodeId = proto.String(k) @@ -983,8 +999,8 @@ func TestRegistryComfyNode(t *testing.T) { res, err := withMiddleware(authz, impl.ComfyNodesBackfill)(ctx, drip.ComfyNodesBackfillRequestObject{}) require.NoError(t, err, "should return created node version") require.IsType(t, drip.ComfyNodesBackfill204Response{}, res) - impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(nodeVersionToBeBackfill)+mockCalled) - mockCalled += len(nodeVersionToBeBackfill) + impl.mockPubsubService.AssertNumberOfCalls(t, "PublishNodePack", len(backfilledNodeVersions)+mockCalled) + mockCalled += len(backfilledNodeVersions) }) t.Run("Limited", func(t *testing.T) { diff --git a/mapper/algolia.go b/mapper/algolia.go new file mode 100644 index 0000000..9d4f0a1 --- /dev/null +++ b/mapper/algolia.go @@ -0,0 +1,53 @@ +package mapper + +import ( + "registry-backend/ent" + "registry-backend/entity" +) + +func AlgoliaNodeFromEntNode(node *ent.Node) entity.AlgoliaNode { + n := entity.AlgoliaNode{ + ObjectID: node.ID, + ID: node.ID, + CreateTime: node.CreateTime, + UpdateTime: node.UpdateTime, + PublisherID: node.PublisherID, + Name: node.Name, + Description: node.Description, + Category: node.Category, + Author: node.Author, + License: node.License, + RepositoryURL: node.RepositoryURL, + IconURL: node.IconURL, + Tags: node.Tags, + TotalInstall: node.TotalInstall, + TotalStar: node.TotalStar, + TotalReview: node.TotalReview, + Status: node.Status, + StatusDetail: node.StatusDetail, + + LatestVersion: "", + LatestVersionStatus: "", + } + + if node.Edges.Versions == nil { + return n + } + + var lv *ent.NodeVersion + for _, v := range node.Edges.Versions { + if lv == nil { + lv = v + } else if v.CreateTime.After(lv.CreateTime) { + lv = v + } + } + n.LatestVersion = lv.Version + n.LatestVersionStatus = lv.Status + n.ComfyNodeNames = make([]string, 0, len(lv.Edges.ComfyNodes)) + for _, v := range lv.Edges.ComfyNodes { + n.ComfyNodeNames = append(n.ComfyNodeNames, v.ID) + } + + return n +} diff --git a/mock/gateways/mock_algolia_service.go b/mock/gateways/mock_algolia_service.go index b4ba792..79e51be 100644 --- a/mock/gateways/mock_algolia_service.go +++ b/mock/gateways/mock_algolia_service.go @@ -12,10 +12,13 @@ var _ algolia.AlgoliaService = &MockAlgoliaService{} type MockAlgoliaService struct { mock.Mock + + LastIndexedNodes []*ent.Node } // IndexNodes implements algolia.AlgoliaService. func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error { + m.LastIndexedNodes = n args := m.Called(ctx, n) return args.Error(0) } diff --git a/services/registry/registry_svc.go b/services/registry/registry_svc.go index 40a6845..1f01915 100644 --- a/services/registry/registry_svc.go +++ b/services/registry/registry_svc.go @@ -116,21 +116,8 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag } // Fetch nodes with pagination - nodes, err := query. - WithVersions(func(q *ent.NodeVersionQuery) { - q.Modify(func(s *sql.Selector) { - s.Where(sql.ExprP( - `(node_id, version) IN ( - SELECT node_id, MAX(version) - FROM node_versions - GROUP BY node_id - )`, - )) - }) - }). - Offset(offset). - Limit(limit). - All(ctx) + query = s.decorateNodeQueryWithLatestVersion(query).Offset(offset).Limit(limit) + nodes, err := query.All(ctx) if err != nil { return nil, fmt.Errorf("failed to list nodes: %w", err) } @@ -295,25 +282,28 @@ func (s *RegistryService) CreateNode(ctx context.Context, client *ent.Client, pu return createdNode, err } -func (s *RegistryService) UpdateNode(ctx context.Context, client *ent.Client, updateFunc func(client *ent.Client) *ent.NodeUpdateOne) (*ent.Node, error) { - var node *ent.Node +func (s *RegistryService) UpdateNode( + ctx context.Context, + client *ent.Client, + updateFunc func(client *ent.Client) *ent.NodeUpdateOne) (*ent.Node, error) { + var n *ent.Node err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { update := updateFunc(tx.Client()) log.Ctx(ctx).Info().Msgf("updating node fields: %v", update.Mutation().Fields()) - node, err = update.Save(ctx) + n, err = update.Save(ctx) if err != nil { return fmt.Errorf("failed to update node: %w", err) } - err = s.algolia.IndexNodes(ctx, node) + _, err = s.indexNodeWithLatestVersion(ctx, tx.Client(), n.ID) if err != nil { return fmt.Errorf("failed to index node: %w", err) } return err }) - return node, err + return n, err } func (s *RegistryService) GetNode(ctx context.Context, client *ent.Client, nodeID string) (*ent.Node, error) { @@ -464,7 +454,7 @@ func (s *RegistryService) ListNodeVersions( }, nil } -func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, nodeId, userID string, star int) (nv *ent.Node, err error) { +func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, nodeId, userID string, star int) (n *ent.Node, err error) { log.Ctx(ctx).Info().Msgf("add review to node: %v ", nodeId) err = db.WithTx(ctx, client, func(tx *ent.Tx) error { @@ -487,16 +477,11 @@ func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, return fmt.Errorf("fail to add review: %w", err) } - nv, err = s.GetNode(ctx, tx.Client(), nodeId) - if err != nil { - return fmt.Errorf("fail to fetch node s") - } - - err = s.algolia.IndexNodes(ctx, nv) + n, err = s.indexNodeWithLatestVersion(ctx, tx.Client(), nodeId) if err != nil { return fmt.Errorf("failed to index node: %w", err) } - + n.Edges.Versions = nil return nil }) @@ -540,11 +525,12 @@ func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Cli func (s *RegistryService) RecordNodeInstallation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) { var n *ent.Node err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { - node, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx) + n, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx) if err != nil { return err } - err = s.algolia.IndexNodes(ctx, node) + + _, err = s.indexNodeWithLatestVersion(ctx, tx.Client(), n.ID) if err != nil { return fmt.Errorf("failed to index node: %w", err) } @@ -586,7 +572,7 @@ var ErrComfyNodesAlreadyExist = errors.New("comfy nodes already exist") func (s *RegistryService) CreateComfyNodes(ctx context.Context, client *ent.Client, nodeID string, nodeVersion string, comfyNodes map[string]drip.ComfyNode) (err error) { return db.WithTx(ctx, client, func(tx *ent.Tx) error { - nv, err := client.NodeVersion.Query(). + nv, err := tx.NodeVersion.Query(). Where(nodeversion.VersionEQ(nodeVersion)). Where(nodeversion.NodeIDEQ(nodeID)). WithComfyNodes(). @@ -602,7 +588,7 @@ func (s *RegistryService) CreateComfyNodes(ctx context.Context, client *ent.Clie comfyNodesCreates := make([]*ent.ComfyNodeCreate, 0, len(comfyNodes)) for k, n := range comfyNodes { - comfyNodeCreate := client.ComfyNode.Create(). + comfyNodeCreate := tx.ComfyNode.Create(). SetID(k). SetNodeVersionID(nv.ID) @@ -635,9 +621,19 @@ func (s *RegistryService) CreateComfyNodes(ctx context.Context, client *ent.Clie } comfyNodesCreates = append(comfyNodesCreates, comfyNodeCreate) } - return client.ComfyNode. + + err = tx.ComfyNode. CreateBulk(comfyNodesCreates...). Exec(ctx) + if err != nil { + return fmt.Errorf("failed to update comfy nodes: %w", err) + } + + if _, err := s.indexNodeWithLatestVersion(ctx, tx.Client(), nodeID); err != nil { + return fmt.Errorf("failed to update node index") + } + + return nil }) } @@ -889,7 +885,10 @@ func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client, return fmt.Errorf("fail to update users: %w", err) } - nodes, err := tx.Node.Query().Where(node.PublisherID(id)).All(ctx) + nodes, err := s.decorateNodeQueryWithLatestVersion( + tx.Node.Query(). + Where(node.PublisherID(id)), + ).All(ctx) if len(nodes) == 0 || ent.IsNotFound(err) { return nil } @@ -912,10 +911,13 @@ func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publi log.Ctx(ctx).Info().Msgf("banning publisher node: %v %v", publisherid, id) return db.WithTx(ctx, client, func(tx *ent.Tx) error { - n, err := tx.Node.Query().Where(node.And( - node.IDEQ(id), - node.PublisherIDEQ(publisherid), - )).Only(ctx) + n, err := s.decorateNodeQueryWithLatestVersion( + tx.Node.Query(). + Where(node.And( + node.IDEQ(id), + node.PublisherIDEQ(publisherid), + ))). + Only(ctx) if ent.IsNotFound(err) { return nil } @@ -970,42 +972,65 @@ func (s *RegistryService) AssertPublisherBanned(ctx context.Context, client *ent func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Client) error { log.Ctx(ctx).Info().Msgf("reindexing nodes") - nodes, err := client.Node.Query(). - WithVersions(func(q *ent.NodeVersionQuery) { - q.Modify(func(s *sql.Selector) { - s.Where(sql.ExprP( - `(node_id, create_time) IN ( - SELECT node_id, MAX(create_time) - FROM node_versions - GROUP BY node_id - )`, - )) - }) - }).All(ctx) + nodes, err := s.decorateNodeQueryWithLatestVersion(client.Node.Query()).All(ctx) if err != nil { return fmt.Errorf("failed to fetch all nodes: %w", err) } - nvs := []*ent.NodeVersion{} - for _, node := range nodes { - nvs = append(nvs, node.Edges.Versions...) - } - log.Ctx(ctx).Info().Msgf("reindexing %d number of nodes", len(nodes)) err = s.algolia.IndexNodes(ctx, nodes...) if err != nil { return fmt.Errorf("failed to reindex all nodes: %w", err) } - log.Ctx(ctx).Info().Msgf("reindexing %d number of node versions", len(nvs)) + var nvs []*ent.NodeVersion + for _, n := range nodes { + nvs = append(nvs, n.Edges.Versions...) + } + + log.Ctx(ctx).Info().Msgf("reindexing %d number of n versions", len(nvs)) err = s.algolia.IndexNodeVersions(ctx, nvs...) if err != nil { - return fmt.Errorf("failed to reindex all node versions: %w", err) + return fmt.Errorf("failed to reindex all n versions: %w", err) } return nil } -func (s *RegistryService) PerformSecurityCheck(ctx context.Context, client *ent.Client, nodeVersion *ent.NodeVersion) error { +// indexNodeWithLatestVersion re-indexes a single node and its latest version +func (s *RegistryService) indexNodeWithLatestVersion( + ctx context.Context, + client *ent.Client, + nodeID string) (*ent.Node, error) { + n, err := s.decorateNodeQueryWithLatestVersion( + client.Node.Query(). + Where(node.IDEQ(nodeID)), + ).Only(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query node: %w", err) + } + if err := s.algolia.IndexNodes(ctx, n); err != nil { + return nil, fmt.Errorf("failed to update node: %w", err) + } + return n, nil +} + +func (s *RegistryService) decorateNodeQueryWithLatestVersion(q *ent.NodeQuery) *ent.NodeQuery { + return q.WithVersions(func(q *ent.NodeVersionQuery) { + q.WithComfyNodes(). + Modify(func(s *sql.Selector) { + s.Where(sql.ExprP( + `(node_id, create_time) IN ( + SELECT node_id, MAX(create_time) + FROM node_versions + GROUP BY node_id + )`, + )) + }) + }) +} + +func (s *RegistryService) PerformSecurityCheck( + ctx context.Context, client *ent.Client, nodeVersion *ent.NodeVersion) error { log.Ctx(ctx).Info().Msgf("Scanning node %s@%s w/ version ID: %s", nodeVersion.NodeID, nodeVersion.Version, nodeVersion.ID)