diff --git a/entity/algolia.go b/entity/algolia.go new file mode 100644 index 0000000..4d35767 --- /dev/null +++ b/entity/algolia.go @@ -0,0 +1,24 @@ +package entity + +import "registry-backend/ent" + +type AlgoliaNode struct { + ObjectID string `json:"objectID"` + *ent.Node + LatestVersion *struct { + *ent.NodeVersion + ComfyNodes map[string]*ent.ComfyNode `json:"comfy_nodes"` + } `json:"latest_version"` +} + +func (n *AlgoliaNode) ToEntNode() *ent.Node { + node := n.Node + if n.LatestVersion != nil { + nv := n.LatestVersion.NodeVersion + for _, v := range n.LatestVersion.ComfyNodes { + nv.Edges.ComfyNodes = append(nv.Edges.ComfyNodes, v) + } + node.Edges.Versions = []*ent.NodeVersion{nv} + } + return node +} diff --git a/gateways/algolia/algolia.go b/gateways/algolia/algolia.go index 28597c7..3c1e48a 100644 --- a/gateways/algolia/algolia.go +++ b/gateways/algolia/algolia.go @@ -3,10 +3,13 @@ package algolia import ( "context" "fmt" - "github.com/algolia/algoliasearch-client-go/v3/algolia/search" - "github.com/rs/zerolog/log" "registry-backend/config" // assuming a config package exists to hold config values "registry-backend/ent" + "registry-backend/entity" + "registry-backend/mapper" + + "github.com/algolia/algoliasearch-client-go/v3/algolia/search" + "github.com/rs/zerolog/log" ) // AlgoliaService defines the interface for interacting with Algolia search. @@ -47,27 +50,10 @@ func NewAlgoliaService(cfg *config.Config) (AlgoliaService, error) { // IndexNodes indexes the provided nodes in Algolia. func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { index := a.client.InitIndex("nodes_index") - objects := make([]map[string]interface{}, len(nodes)) + objects := make([]entity.AlgoliaNode, len(nodes)) for i, n := range nodes { - o := map[string]interface{}{ - "objectID": n.ID, - "name": n.Name, - "publisher_id": n.PublisherID, - "description": n.Description, - "id": n.ID, - "create_time": n.CreateTime, - "update_time": n.UpdateTime, - "license": n.License, - "repository_url": n.RepositoryURL, - "total_install": n.TotalInstall, - "status": n.Status, - "author": n.Author, - "category": n.Category, - "total_star": n.TotalStar, - "total_review": n.TotalReview, - } - objects[i] = o + objects[i] = mapper.AlgoliaNodeFromEntNode(n) } res, err := index.SaveObjects(objects) @@ -79,18 +65,21 @@ func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { } // SearchNodes searches for nodes in Algolia matching the query. -func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { +func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) { index := a.client.InitIndex("nodes_index") res, err := index.Search(query, opts...) if err != nil { return nil, fmt.Errorf("failed to search nodes: %w", err) } - var nodes []*ent.Node - if err := res.UnmarshalHits(&nodes); err != nil { + var algoliaNodes []entity.AlgoliaNode + if err := res.UnmarshalHits(&algoliaNodes); err != nil { return nil, fmt.Errorf("failed to unmarshal search results: %w", err) } - return nodes, nil + for _, n := range algoliaNodes { + nodes = append(nodes, n.ToEntNode()) + } + return } // DeleteNode deletes the specified node from Algolia. diff --git a/gateways/algolia/algolia_test.go b/gateways/algolia/algolia_test.go index 1c09723..c9f01db 100644 --- a/gateways/algolia/algolia_test.go +++ b/gateways/algolia/algolia_test.go @@ -6,6 +6,7 @@ import ( "registry-backend/config" "registry-backend/ent" "registry-backend/ent/schema" + "sort" "testing" "time" @@ -32,12 +33,68 @@ func TestIndex(t *testing.T) { t.Run("node", func(t *testing.T) { ctx := context.Background() + id := uuid.New() node := &ent.Node{ - ID: uuid.NewString(), - Name: t.Name() + "-" + uuid.NewString(), - TotalStar: 98, - TotalReview: 20, + ID: id.String(), + CreateTime: time.Time{}, + UpdateTime: time.Time{}, + PublisherID: "id", + Name: t.Name() + "-" + uuid.NewString(), + Description: "desc", + Category: "cat", + Author: "au", + License: "license", + RepositoryURL: "somerepo", + IconURL: "someicon", + Tags: []string{"tags"}, + TotalInstall: 10, + TotalStar: 98, + TotalReview: 20, + Status: "status", + StatusDetail: "status detail", + Edges: ent.NodeEdges{Versions: []*ent.NodeVersion{ + { + ID: id, + NodeID: id.String(), + Version: "v1.0.0-" + uuid.NewString(), + Changelog: "test", + Status: schema.NodeVersionStatusActive, + StatusReason: "test", + PipDependencies: []string{"test"}, + Edges: ent.NodeVersionEdges{ComfyNodes: []*ent.ComfyNode{ + { + ID: "node1", + NodeVersionID: id, + Category: "test", + Function: "test", + Description: "test", + Deprecated: false, + Experimental: false, + InputTypes: "test", + OutputIsList: []bool{true}, + ReturnNames: []string{"test"}, + ReturnTypes: "test", + CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), + }, + { + ID: "node2", + NodeVersionID: id, + Category: "test", + Function: "test", + Description: "test", + Deprecated: false, + Experimental: false, + InputTypes: "test", + OutputIsList: []bool{true}, + ReturnNames: []string{"test"}, + ReturnTypes: "test", + CreateTime: time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }}, + }, + }}, } + for i := 0; i < 10; i++ { err = algolia.IndexNodes(ctx, node) require.NoError(t, err) @@ -47,6 +104,10 @@ func TestIndex(t *testing.T) { nodes, err := algolia.SearchNodes(ctx, node.Name) require.NoError(t, err) require.Len(t, nodes, 1) + // sometimes the order is mixed and assertion fail + sort.Slice(nodes[0].Edges.Versions[0].Edges.ComfyNodes, func(i, j int) bool { + return nodes[0].Edges.Versions[0].Edges.ComfyNodes[i].ID < nodes[0].Edges.Versions[0].Edges.ComfyNodes[j].ID + }) assert.Equal(t, node, nodes[0]) }) diff --git a/integration-tests/ban_test.go b/integration-tests/ban_test.go new file mode 100644 index 0000000..ccca358 --- /dev/null +++ b/integration-tests/ban_test.go @@ -0,0 +1,305 @@ +package integration + +import ( + "context" + "net/http" + "registry-backend/config" + "registry-backend/drip" + "registry-backend/ent/schema" + "registry-backend/mock/gateways" + "registry-backend/server/implementation" + drip_authorization "registry-backend/server/middleware/authorization" + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestBan(t *testing.T) { + clientCtx := context.Background() + client, cleanup := setupDB(t, clientCtx) + defer cleanup() + + // Initialize the Service + mockStorageService := new(gateways.MockStorageService) + mockPubsubService := new(gateways.MockPubSubService) + mockSlackService := new(gateways.MockSlackService) + mockDiscordService := new(gateways.MockDiscordService) + mockSlackService. + On("SendRegistryMessageToSlack", mock.Anything). + Return(nil) // Do nothing for all slack messsage calls. + mockAlgolia := new(gateways.MockAlgoliaService) + mockAlgolia. + On("IndexNodes", mock.Anything, mock.Anything). + Return(nil) + + impl := implementation.NewStrictServerImplementation( + client, &config.Config{}, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia) + + authz := drip_authorization.NewAuthorizationManager(client, impl.RegistryService).AuthorizationMiddleware() + + t.Run("Publisher", func(t *testing.T) { + t.Run("Ban", func(t *testing.T) { + ctx, user := setUpTest(client) + + publisherId := "test-publisher" + description := "test-description" + source_code_repo := "test-source-code-repo" + website := "test-website" + support := "test-support" + logo := "test-logo" + name := "test-name" + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: &drip.Publisher{ + Id: &publisherId, + Description: &description, + SourceCodeRepo: &source_code_repo, + Website: &website, + Support: &support, + Logo: &logo, + Name: &name, + }, + }) + require.NoError(t, err, "should return created publisher") + + nodeId := "test-node" + nodeDescription := "test-node-description" + nodeAuthor := "test-node-author" + nodeLicense := "test-node-license" + nodeName := "test-node-name" + nodeTags := []string{"test-node-tag"} + icon := "https://wwww.github.com/test-icon.svg" + githubUrl := "https://www.github.com/test-github-url" + _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: publisherId, + Body: &drip.Node{ + Id: &nodeId, + Name: &nodeName, + Description: &nodeDescription, + Author: &nodeAuthor, + License: &nodeLicense, + Tags: &nodeTags, + Icon: &icon, + Repository: &githubUrl, + }, + }) + require.NoError(t, err, "should return created node") + + t.Run("By Non Admin", func(t *testing.T) { + ctx, _ := setUpTest(client) + res, err := withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err, "should not ban publisher") + require.IsType(t, drip.BanPublisher403JSONResponse{}, res) + }) + + t.Run("By Admin", func(t *testing.T) { + ctx, admin := setUpTest(client) + err = admin.Update().SetIsAdmin(true).Exec(clientCtx) + require.NoError(t, err) + _, err = withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err) + + pub, err := client.Publisher.Get(ctx, publisherId) + require.NoError(t, err) + assert.Equal(t, schema.PublisherStatusTypeBanned, pub.Status, "should ban publisher") + user, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, schema.UserStatusTypeBanned, user.Status, "should ban user") + node, err := client.Node.Get(ctx, nodeId) + require.NoError(t, err) + assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") + }) + }) + + t.Run("Access", func(t *testing.T) { + testtable := []struct { + name string + invoke func(ctx context.Context) error + }{ + { + name: "CreatePublisher", + invoke: func(ctx context.Context) error { + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{Body: &drip.Publisher{}}) + return err + }, + }, + { + name: "DeleteNodeVersion", + invoke: func(ctx context.Context) error { + _, err := withMiddleware(authz, impl.DeleteNodeVersion)(ctx, drip.DeleteNodeVersionRequestObject{}) + return err + }, + }, + } + + t.Run("Banned", func(t *testing.T) { + ctxBanned, testUserBanned := setUpTest(client) + err := testUserBanned.Update().SetStatus(schema.UserStatusTypeBanned).Exec(ctxBanned) + require.NoError(t, err) + for _, tt := range testtable { + t.Run(tt.name, func(t *testing.T) { + err = tt.invoke(ctxBanned) + require.Error(t, err, "should return error") + require.IsType(t, &echo.HTTPError{}, err, "should return echo http error") + echoErr := err.(*echo.HTTPError) + assert.Equal(t, http.StatusForbidden, echoErr.Code, "should return 403") + }) + } + }) + + t.Run("Not Banned", func(t *testing.T) { + ctx, _ := setUpTest(client) + for _, tt := range testtable { + t.Run(tt.name, func(t *testing.T) { + err := tt.invoke(ctx) + _, ok := err.(*echo.HTTPError) + assert.False(t, ok, err, "should pass the authorization middleware") + }) + } + }) + }) + }) + + t.Run("Node", func(t *testing.T) { + ctx, _ := setUpTest(client) + + publisherId := "test-publisher-1" + description := "test-description" + source_code_repo := "test-source-code-repo" + website := "test-website" + support := "test-support" + logo := "test-logo" + name := "test-name" + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: &drip.Publisher{ + Id: &publisherId, + Description: &description, + SourceCodeRepo: &source_code_repo, + Website: &website, + Support: &support, + Logo: &logo, + Name: &name, + }, + }) + require.NoError(t, err, "should return created publisher") + + nodeId := "test-node-1" + nodeDescription := "test-node-description" + nodeAuthor := "test-node-author" + nodeLicense := "test-node-license" + nodeName := "test-node-name" + nodeTags := []string{"test-node-tag"} + icon := "https://wwww.github.com/test-icon.svg" + githubUrl := "https://www.github.com/test-github-url" + _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: publisherId, + Body: &drip.Node{ + Id: &nodeId, + Name: &nodeName, + Description: &nodeDescription, + Author: &nodeAuthor, + License: &nodeLicense, + Tags: &nodeTags, + Icon: &icon, + Repository: &githubUrl, + }, + }) + require.NoError(t, err, "should return created node") + + tokenName := "name" + tokenDescription := "name" + res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ + PublisherId: publisherId, + Body: &drip.PersonalAccessToken{ + Name: &tokenName, + Description: &tokenDescription, + }, + }) + require.NoError(t, err, "should return created token") + require.IsType(t, drip.CreatePersonalAccessToken201JSONResponse{}, res) + pat := res.(drip.CreatePersonalAccessToken201JSONResponse).Token + + t.Run("Ban", func(t *testing.T) { + t.Run("By Non Admin", func(t *testing.T) { + ctx, _ := setUpTest(client) + res, err := withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err, "should not ban publisher node") + require.IsType(t, drip.BanPublisherNode403JSONResponse{}, res) + }) + + t.Run("By Admin", func(t *testing.T) { + ctx, admin := setUpTest(client) + err = admin.Update().SetIsAdmin(true).Exec(clientCtx) + require.NoError(t, err) + _, err = withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err) + + node, err := client.Node.Get(ctx, nodeId) + require.NoError(t, err) + assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") + }) + }) + + t.Run("Operate", func(t *testing.T) { + t.Run("Get", func(t *testing.T) { + f := withMiddleware(authz, impl.GetNode) + _, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("Update", func(t *testing.T) { + f := withMiddleware(authz, impl.UpdateNode) + _, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("ListNodeVersion", func(t *testing.T) { + f := withMiddleware(authz, impl.ListNodeVersions) + _, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("PublishNodeVersion", func(t *testing.T) { + f := withMiddleware(authz, impl.PublishNodeVersion) + _, err := f(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: publisherId, NodeId: nodeId, + Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat}, + }) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("InstallNode", func(t *testing.T) { + f := withMiddleware(authz, impl.InstallNode) + _, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("SearchNodes", func(t *testing.T) { + f := withMiddleware(authz, impl.SearchNodes) + res, err := f(ctx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{}, + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + require.Empty(t, res.(drip.SearchNodes200JSONResponse).Nodes) + + res, err = f(ctx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)}, + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + require.NotEmpty(t, res.(drip.SearchNodes200JSONResponse).Nodes) + }) + }) + }) + +} diff --git a/integration-tests/test_util.go b/integration-tests/test_util.go index 0230023..5dad3de 100644 --- a/integration-tests/test_util.go +++ b/integration-tests/test_util.go @@ -3,19 +3,11 @@ package integration import ( "context" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/proto" "net" "net/http" "net/http/httptest" "reflect" - "regexp" - "registry-backend/config" "registry-backend/drip" - "registry-backend/mock/gateways" - "registry-backend/server/implementation" auth "registry-backend/server/middleware/authentication" "runtime" "strings" @@ -35,207 +27,6 @@ import ( _ "github.com/lib/pq" ) -// NewStrictServerImplementationWithMocks initializes and returns the implementation with mock services. -func NewStrictServerImplementationWithMocks( - client *ent.Client, config *config.Config) *implementation.DripStrictServerImplementation { - // Mock services setup - mockStorageService := new(gateways.MockStorageService) - mockPubsubService := new(gateways.MockPubSubService) - mockSlackService := new(gateways.MockSlackService) - mockDiscordService := new(gateways.MockDiscordService) - mockAlgolia := new(gateways.MockAlgoliaService) - - // Mock service expectations - mockSlackService.On("SendRegistryMessageToSlack", mock.Anything).Return(nil) - mockAlgolia.On("IndexNodes", mock.Anything, mock.Anything).Return(nil) - - // Return the new implementation with mocked services - return implementation.NewStrictServerImplementation( - client, - config, - mockStorageService, - mockPubsubService, - mockSlackService, - mockDiscordService, - mockAlgolia, - ) -} - -func setupTestUser(client *ent.Client) (context.Context, *ent.User) { - // Create a new context and a test user - ctx := context.Background() - testUser := createTestUser(ctx, client) - - // Attach the test user to the context - ctx = decorateUserInContext(ctx, testUser) - - // Return the context and the created test user - return ctx, testUser -} - -func setupAdminUser(client *ent.Client) (context.Context, *ent.User) { - // Create a new context to isolate the test setup - ctx := context.Background() - - // Attempt to create the admin user - testUser := createAdminUser(ctx, client) - - // Decorate the user in the context - ctx = decorateUserInContext(ctx, testUser) - - // Return the decorated context and the created user - return ctx, testUser -} - -// Helper function to set up a personal access token -func setupPersonalAccessToken( - ctx context.Context, - authz drip.StrictMiddlewareFunc, - impl *implementation.DripStrictServerImplementation, - publisherId string) (*string, error) { - - tokenName := "test-token" - tokenDescription := "test-description" - res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ - PublisherId: publisherId, - Body: &drip.PersonalAccessToken{ - Name: &tokenName, - Description: &tokenDescription, - }, - }) - if err != nil { - return nil, err - } - - // Extract the created token from the response - pat := res.(drip.CreatePersonalAccessToken201JSONResponse).Token - return pat, nil -} - -// Helper function to generate a valid publisher ID based on the pattern "^[a-z][a-z0-9-]*$" -func generatePublisherId() string { - // Generate a random UUID and use a portion of it for the publisher ID - rawId := uuid.New().String() - // Strip hyphens and convert to lowercase to fit the pattern - id := strings.ToLower(strings.ReplaceAll(rawId, "-", "")) - // Ensure the ID starts with a letter and follows the pattern - if match, _ := regexp.MatchString("^[a-z][a-z0-9-]*$", id); match { - return id - } - // If it doesn't match, regenerate a valid ID - return generatePublisherId() -} - -// Helper function to generate a valid node ID based on the pattern "^[a-z][a-z0-9-_]+(\\.[a-z0-9-_]+)*$" -func generateNodeId() string { - // Generate a random UUID and use a portion of it for the node ID - rawId := uuid.New().String() - // Strip hyphens and convert to lowercase to fit the pattern - id := strings.ToLower(strings.ReplaceAll(rawId, "-", "")) - // Ensure the ID starts with a letter and follows the pattern - if match, _ := regexp.MatchString("^[a-z][a-z0-9-_]+(\\.[a-z0-9-_]+)*$", id); match { - return id - } - // If it doesn't match, regenerate a valid ID - return generateNodeId() -} - -// Helper function to generate a random Publisher for testing -func randomPublisher() *drip.Publisher { - suffix := uuid.New().String() - publisherId := generatePublisherId() - - description := "test-description-" + suffix - sourceCodeRepo := "https://github.com/test-repo-" + suffix - website := "https://test-website-" + suffix + ".com" - support := "test-support-" + suffix - logo := "https://test-logo-" + suffix + ".png" - name := "test-name-" + suffix - - return &drip.Publisher{ - Id: &publisherId, - Name: &name, - Description: &description, - SourceCodeRepo: &sourceCodeRepo, - Website: &website, - Support: &support, - Logo: &logo, - } -} - -// Helper function to generate a random Node for testing -func randomNode() *drip.Node { - suffix := uuid.New().String() - nodeId := generateNodeId() - - description := "test-node-description-" + suffix - author := "test-node-author-" + suffix - license := "test-node-license-" + suffix - name := "test-node-name-" + suffix - tags := []string{"test-node-tag"} - icon := "https://www.github.com/test-icon-" + suffix + ".svg" - repository := "https://www.github.com/test-repo-" + suffix - - return &drip.Node{ - Id: &nodeId, - Name: &name, - Description: &description, - Author: &author, - License: &license, - Tags: &tags, - Icon: &icon, - Repository: &repository, - } -} - -// Helper function to generate a random NodeVersion for testing -func randomNodeVersion(revision int) *drip.NodeVersion { - suffix := uuid.New().String() - - version := fmt.Sprintf("1.0.%d", revision) - changelog := "test-changelog-" + suffix - dependencies := []string{"test-dependency-" + suffix} - - return &drip.NodeVersion{ - Version: &version, - Changelog: &changelog, - Dependencies: &dependencies, - } -} - -// Helper function to set up a publisher with a random ID -func setupPublisher( - ctx context.Context, - authz drip.StrictMiddlewareFunc, - impl *implementation.DripStrictServerImplementation) (string, error) { - - publisher := randomPublisher() - - _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ - Body: publisher, - }) - return *publisher.Id, err -} - -// Helper function to set up a node with a random ID -func setupNode( - ctx context.Context, - authz drip.StrictMiddlewareFunc, - impl *implementation.DripStrictServerImplementation, publisherId string) (string, error) { - - node := randomNode() - node.Id = proto.String(generateNodeId()) - node.Publisher = &drip.Publisher{ - Id: proto.String(publisherId), - } - - _, err := withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ - PublisherId: publisherId, - Body: node, - }) - return *node.Id, err -} - func createTestUser(ctx context.Context, client *ent.Client) *ent.User { return client.User.Create(). SetID(uuid.New().String()). @@ -342,41 +133,24 @@ func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration) } func withMiddleware[R any, S any](mw drip.StrictMiddlewareFunc, h func(ctx context.Context, req R) (res S, err error)) func(ctx context.Context, req R) (res S, err error) { - // Adapt the provided handler `h` to the signature expected by the middleware. handler := func(ctx echo.Context, request interface{}) (interface{}, error) { - // Convert the `echo.Context` to a standard `context.Context` and cast the request to the expected type. return h(ctx.Request().Context(), request.(R)) } - // Use reflection to extract the operation name (function name) of the handler for logging or debugging purposes. - nameParts := strings.Split(runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), ".") - nameParts = strings.Split(nameParts[len(nameParts)-1], "-") - opname := nameParts[0] // Isolate the operation name. + nameA := strings.Split(runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), ".") + nameA = strings.Split(nameA[len(nameA)-1], "-") + opname := nameA[0] return func(ctx context.Context, req R) (res S, err error) { - // Create a simulated echo.Context with fake HTTP request and response. fakeReq := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) fakeRes := httptest.NewRecorder() fakeCtx := echo.New().NewContext(fakeReq, fakeRes) - // Wrap the adapted handler with the middleware. - wrappedHandler := mw(handler, opname) - - // Invoke the middleware-wrapped handler and cast the result to the expected response type. - result, err := wrappedHandler(fakeCtx, req) - if result == nil { - // Return a zero-value of type S if the result is nil. + f := mw(handler, opname) + r, err := f(fakeCtx, req) + if r == nil { return *new(S), err } - - // Type assert the result to the expected response type S and return. - return result.(S), err + return r.(S), err } } - -// Helper function for checking error type and code -func assertHTTPError(t *testing.T, err error, expectedCode int) { - require.IsType(t, &echo.HTTPError{}, err) - echoErr := err.(*echo.HTTPError) - assert.Equal(t, expectedCode, echoErr.Code, "should return correct HTTP error code") -} diff --git a/mapper/algolia.go b/mapper/algolia.go new file mode 100644 index 0000000..33a472f --- /dev/null +++ b/mapper/algolia.go @@ -0,0 +1,42 @@ +package mapper + +import ( + "registry-backend/ent" + "registry-backend/entity" +) + +func AlgoliaNodeFromEntNode(node *ent.Node) entity.AlgoliaNode { + n := entity.AlgoliaNode{ + ObjectID: node.ID, + Node: new(ent.Node), + } + *n.Node = *node + n.Edges = ent.NodeEdges{} + if node.Edges.Versions == nil { + return n + } + + var lv *ent.NodeVersion + for _, v := range node.Edges.Versions { + if lv == nil { + lv = v + } else if v.CreateTime.After(lv.CreateTime) { + lv = v + } + } + + n.LatestVersion = &struct { + *ent.NodeVersion + ComfyNodes map[string]*ent.ComfyNode `json:"comfy_nodes"` + }{ + NodeVersion: new(ent.NodeVersion), + ComfyNodes: make(map[string]*ent.ComfyNode, len(lv.Edges.ComfyNodes)), + } + *n.LatestVersion.NodeVersion = *lv + n.LatestVersion.NodeVersion.Edges = ent.NodeVersionEdges{} + for _, v := range lv.Edges.ComfyNodes { + n.LatestVersion.ComfyNodes[v.ID] = v + } + + return n +} diff --git a/mock/gateways/mock_algolia_service.go b/mock/gateways/mock_algolia_service.go index b4ba792..79e51be 100644 --- a/mock/gateways/mock_algolia_service.go +++ b/mock/gateways/mock_algolia_service.go @@ -12,10 +12,13 @@ var _ algolia.AlgoliaService = &MockAlgoliaService{} type MockAlgoliaService struct { mock.Mock + + LastIndexedNodes []*ent.Node } // IndexNodes implements algolia.AlgoliaService. func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error { + m.LastIndexedNodes = n args := m.Called(ctx, n) return args.Error(0) } diff --git a/server/middleware/authorization/authorization_middleware.go b/server/middleware/authorization/authorization_middleware.go index 183df99..12f8228 100644 --- a/server/middleware/authorization/authorization_middleware.go +++ b/server/middleware/authorization/authorization_middleware.go @@ -133,20 +133,6 @@ func (m *AuthorizationManager) AuthorizationMiddleware() drip.StrictMiddlewareFu return req.(drip.DeleteNodeVersionRequestObject).NodeId }, ), - m.assertPublisherPermission( - []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, - func(req interface{}) (publisherID string) { - return req.(drip.DeleteNodeVersionRequestObject).PublisherId - }, - ), - m.assertNodeBelongsToPublisher( - func(req interface{}) (publisherID string) { - return req.(drip.DeleteNodeVersionRequestObject).PublisherId - }, - func(req interface{}) (publisherID string) { - return req.(drip.DeleteNodeVersionRequestObject).NodeId - }, - ), }, "GetNodeVersion": { m.assertNodeBanned( diff --git a/server/server.go b/server/server.go index f281733..6a08f67 100644 --- a/server/server.go +++ b/server/server.go @@ -1,7 +1,10 @@ package server import ( + monitoring "cloud.google.com/go/monitoring/apiv3/v2" "context" + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" "registry-backend/config" generated "registry-backend/drip" "registry-backend/ent" @@ -14,13 +17,8 @@ import ( "registry-backend/server/implementation" "registry-backend/server/middleware" "registry-backend/server/middleware/authentication" - drip_authorization "registry-backend/server/middleware/authorization" + "registry-backend/server/middleware/authorization" "registry-backend/server/middleware/metric" - - monitoring "cloud.google.com/go/monitoring/apiv3/v2" - "github.com/labstack/echo/v4" - labstack_middleware "github.com/labstack/echo/v4/middleware" - "github.com/rs/zerolog/log" ) type ServerDependencies struct { @@ -95,15 +93,6 @@ func (s *Server) Start() error { // Apply middleware e.Use(middleware.TracingMiddleware) - e.Use(labstack_middleware.CORSWithConfig(labstack_middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - AllowMethods: []string{"*"}, - AllowHeaders: []string{"*"}, - AllowOriginFunc: func(origin string) (bool, error) { - return true, nil - }, - AllowCredentials: true, - })) e.Use(middleware.RequestLoggerMiddleware()) e.Use(middleware.ResponseLoggerMiddleware()) e.Use(metric.MetricsMiddleware(&s.Dependencies.MonitoringClient, s.Config))