diff --git a/integration-tests/ban_test.go b/integration-tests/ban_test.go deleted file mode 100644 index ccca358..0000000 --- a/integration-tests/ban_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package integration - -import ( - "context" - "net/http" - "registry-backend/config" - "registry-backend/drip" - "registry-backend/ent/schema" - "registry-backend/mock/gateways" - "registry-backend/server/implementation" - drip_authorization "registry-backend/server/middleware/authorization" - "testing" - - "google.golang.org/protobuf/proto" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -func TestBan(t *testing.T) { - clientCtx := context.Background() - client, cleanup := setupDB(t, clientCtx) - defer cleanup() - - // Initialize the Service - mockStorageService := new(gateways.MockStorageService) - mockPubsubService := new(gateways.MockPubSubService) - mockSlackService := new(gateways.MockSlackService) - mockDiscordService := new(gateways.MockDiscordService) - mockSlackService. - On("SendRegistryMessageToSlack", mock.Anything). - Return(nil) // Do nothing for all slack messsage calls. - mockAlgolia := new(gateways.MockAlgoliaService) - mockAlgolia. - On("IndexNodes", mock.Anything, mock.Anything). - Return(nil) - - impl := implementation.NewStrictServerImplementation( - client, &config.Config{}, mockStorageService, mockPubsubService, mockSlackService, mockDiscordService, mockAlgolia) - - authz := drip_authorization.NewAuthorizationManager(client, impl.RegistryService).AuthorizationMiddleware() - - t.Run("Publisher", func(t *testing.T) { - t.Run("Ban", func(t *testing.T) { - ctx, user := setUpTest(client) - - publisherId := "test-publisher" - description := "test-description" - source_code_repo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - }, - }) - require.NoError(t, err, "should return created publisher") - - nodeId := "test-node" - nodeDescription := "test-node-description" - nodeAuthor := "test-node-author" - nodeLicense := "test-node-license" - nodeName := "test-node-name" - nodeTags := []string{"test-node-tag"} - icon := "https://wwww.github.com/test-icon.svg" - githubUrl := "https://www.github.com/test-github-url" - _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ - PublisherId: publisherId, - Body: &drip.Node{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Tags: &nodeTags, - Icon: &icon, - Repository: &githubUrl, - }, - }) - require.NoError(t, err, "should return created node") - - t.Run("By Non Admin", func(t *testing.T) { - ctx, _ := setUpTest(client) - res, err := withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) - require.NoError(t, err, "should not ban publisher") - require.IsType(t, drip.BanPublisher403JSONResponse{}, res) - }) - - t.Run("By Admin", func(t *testing.T) { - ctx, admin := setUpTest(client) - err = admin.Update().SetIsAdmin(true).Exec(clientCtx) - require.NoError(t, err) - _, err = withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) - require.NoError(t, err) - - pub, err := client.Publisher.Get(ctx, publisherId) - require.NoError(t, err) - assert.Equal(t, schema.PublisherStatusTypeBanned, pub.Status, "should ban publisher") - user, err := client.User.Get(ctx, user.ID) - require.NoError(t, err) - assert.Equal(t, schema.UserStatusTypeBanned, user.Status, "should ban user") - node, err := client.Node.Get(ctx, nodeId) - require.NoError(t, err) - assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") - }) - }) - - t.Run("Access", func(t *testing.T) { - testtable := []struct { - name string - invoke func(ctx context.Context) error - }{ - { - name: "CreatePublisher", - invoke: func(ctx context.Context) error { - _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{Body: &drip.Publisher{}}) - return err - }, - }, - { - name: "DeleteNodeVersion", - invoke: func(ctx context.Context) error { - _, err := withMiddleware(authz, impl.DeleteNodeVersion)(ctx, drip.DeleteNodeVersionRequestObject{}) - return err - }, - }, - } - - t.Run("Banned", func(t *testing.T) { - ctxBanned, testUserBanned := setUpTest(client) - err := testUserBanned.Update().SetStatus(schema.UserStatusTypeBanned).Exec(ctxBanned) - require.NoError(t, err) - for _, tt := range testtable { - t.Run(tt.name, func(t *testing.T) { - err = tt.invoke(ctxBanned) - require.Error(t, err, "should return error") - require.IsType(t, &echo.HTTPError{}, err, "should return echo http error") - echoErr := err.(*echo.HTTPError) - assert.Equal(t, http.StatusForbidden, echoErr.Code, "should return 403") - }) - } - }) - - t.Run("Not Banned", func(t *testing.T) { - ctx, _ := setUpTest(client) - for _, tt := range testtable { - t.Run(tt.name, func(t *testing.T) { - err := tt.invoke(ctx) - _, ok := err.(*echo.HTTPError) - assert.False(t, ok, err, "should pass the authorization middleware") - }) - } - }) - }) - }) - - t.Run("Node", func(t *testing.T) { - ctx, _ := setUpTest(client) - - publisherId := "test-publisher-1" - description := "test-description" - source_code_repo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - }, - }) - require.NoError(t, err, "should return created publisher") - - nodeId := "test-node-1" - nodeDescription := "test-node-description" - nodeAuthor := "test-node-author" - nodeLicense := "test-node-license" - nodeName := "test-node-name" - nodeTags := []string{"test-node-tag"} - icon := "https://wwww.github.com/test-icon.svg" - githubUrl := "https://www.github.com/test-github-url" - _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ - PublisherId: publisherId, - Body: &drip.Node{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Tags: &nodeTags, - Icon: &icon, - Repository: &githubUrl, - }, - }) - require.NoError(t, err, "should return created node") - - tokenName := "name" - tokenDescription := "name" - res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ - PublisherId: publisherId, - Body: &drip.PersonalAccessToken{ - Name: &tokenName, - Description: &tokenDescription, - }, - }) - require.NoError(t, err, "should return created token") - require.IsType(t, drip.CreatePersonalAccessToken201JSONResponse{}, res) - pat := res.(drip.CreatePersonalAccessToken201JSONResponse).Token - - t.Run("Ban", func(t *testing.T) { - t.Run("By Non Admin", func(t *testing.T) { - ctx, _ := setUpTest(client) - res, err := withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) - require.NoError(t, err, "should not ban publisher node") - require.IsType(t, drip.BanPublisherNode403JSONResponse{}, res) - }) - - t.Run("By Admin", func(t *testing.T) { - ctx, admin := setUpTest(client) - err = admin.Update().SetIsAdmin(true).Exec(clientCtx) - require.NoError(t, err) - _, err = withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) - require.NoError(t, err) - - node, err := client.Node.Get(ctx, nodeId) - require.NoError(t, err) - assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") - }) - }) - - t.Run("Operate", func(t *testing.T) { - t.Run("Get", func(t *testing.T) { - f := withMiddleware(authz, impl.GetNode) - _, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) - require.Error(t, err) - require.IsType(t, &echo.HTTPError{}, err) - assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) - }) - t.Run("Update", func(t *testing.T) { - f := withMiddleware(authz, impl.UpdateNode) - _, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) - require.Error(t, err) - require.IsType(t, &echo.HTTPError{}, err) - assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) - }) - t.Run("ListNodeVersion", func(t *testing.T) { - f := withMiddleware(authz, impl.ListNodeVersions) - _, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) - require.Error(t, err) - require.IsType(t, &echo.HTTPError{}, err) - assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) - }) - t.Run("PublishNodeVersion", func(t *testing.T) { - f := withMiddleware(authz, impl.PublishNodeVersion) - _, err := f(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: publisherId, NodeId: nodeId, - Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat}, - }) - require.Error(t, err) - require.IsType(t, &echo.HTTPError{}, err) - assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) - }) - t.Run("InstallNode", func(t *testing.T) { - f := withMiddleware(authz, impl.InstallNode) - _, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) - require.Error(t, err) - require.IsType(t, &echo.HTTPError{}, err) - assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) - }) - t.Run("SearchNodes", func(t *testing.T) { - f := withMiddleware(authz, impl.SearchNodes) - res, err := f(ctx, drip.SearchNodesRequestObject{ - Params: drip.SearchNodesParams{}, - }) - require.NoError(t, err) - require.IsType(t, drip.SearchNodes200JSONResponse{}, res) - require.Empty(t, res.(drip.SearchNodes200JSONResponse).Nodes) - - res, err = f(ctx, drip.SearchNodesRequestObject{ - Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)}, - }) - require.NoError(t, err) - require.IsType(t, drip.SearchNodes200JSONResponse{}, res) - require.NotEmpty(t, res.(drip.SearchNodes200JSONResponse).Nodes) - }) - }) - }) - -} diff --git a/integration-tests/node_ban_test.go b/integration-tests/node_ban_test.go new file mode 100644 index 0000000..defdc3a --- /dev/null +++ b/integration-tests/node_ban_test.go @@ -0,0 +1,142 @@ +package integration + +import ( + "context" + "google.golang.org/protobuf/proto" + "net/http" + "registry-backend/config" + "registry-backend/drip" + "registry-backend/ent/schema" + drip_authorization "registry-backend/server/middleware/authorization" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNodeBan(t *testing.T) { + clientCtx := context.Background() + client, cleanup := setupDB(t, clientCtx) + defer cleanup() + + // Setup the mock services and server + impl := NewStrictServerImplementationWithMocks(client, &config.Config{}) + authz := drip_authorization.NewAuthorizationManager(client, impl.RegistryService).AuthorizationMiddleware() + + t.Run("Node Ban Tests", func(t *testing.T) { + userCtx, _ := setupTestUser(client) + adminCtx, _ := setupAdminUser(client) + + // Setup a test publisher + publisherId, err := setupPublisher(userCtx, authz, impl) + require.NoError(t, err, "should set up publisher") + + // Setup a test node + nodeId, err := setupNode(userCtx, authz, impl, publisherId) + require.NoError(t, err, "should set up node") + + // Setup a personal access token + pat, err := setupPersonalAccessToken(userCtx, authz, impl, publisherId) + require.NoError(t, err, "should set up personal access token") + + t.Run("Ban node by non-admin", func(t *testing.T) { + // Attempt to ban the node as a non-admin user + res, err := withMiddleware(authz, impl.BanPublisherNode)( + userCtx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err, "should not ban node by non-admin") + require.IsType(t, drip.BanPublisherNode403JSONResponse{}, res) + }) + + t.Run("Ban node by admin", func(t *testing.T) { + // Attempt to ban the node as an admin user + _, err := withMiddleware(authz, impl.BanPublisherNode)( + adminCtx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err) + + // Verify that the node is banned + node, err := client.Node.Get(adminCtx, nodeId) + require.NoError(t, err) + assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") + }) + + t.Run("Calling endpoints with a banned node", func(t *testing.T) { + // endpoints to test the authorization middleware + testEndpoints := []struct { + name string + invoke func(ctx context.Context) error + }{ + {"GetNode", func(ctx context.Context) error { + f := withMiddleware(authz, impl.GetNode) + _, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) + return err + }}, + {"UpdateNode", func(ctx context.Context) error { + f := withMiddleware(authz, impl.UpdateNode) + _, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + return err + }}, + {"ListNodeVersions", func(ctx context.Context) error { + f := withMiddleware(authz, impl.ListNodeVersions) + _, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) + return err + }}, + {"PublishNodeVersion", func(ctx context.Context) error { + f := withMiddleware(authz, impl.PublishNodeVersion) + _, err := f(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: publisherId, NodeId: nodeId, + Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat}, + }) + return err + }}, + {"InstallNode", func(ctx context.Context) error { + f := withMiddleware(authz, impl.InstallNode) + _, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) + return err + }}, + } + + for _, tc := range testEndpoints { + t.Run(tc.name, func(t *testing.T) { + err := tc.invoke(userCtx) + require.Error(t, err, "should return error") + assertHTTPError(t, err, http.StatusForbidden) + }) + } + }) + + t.Run("SearchNodes with banned node", func(t *testing.T) { + // Step 1: Perform a search without including banned nodes (default behavior). + f := withMiddleware(authz, impl.SearchNodes) + res, err := f(userCtx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(false)}, // Explicitly do not include banned nodes. + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + + // Assert that no nodes are returned when IncludeBanned is false (since the node should be banned). + searchResponse := res.(drip.SearchNodes200JSONResponse) + require.Empty(t, searchResponse.Nodes, "Search should not include banned nodes") + + // Step 2: Perform a search including banned nodes. + res, err = f(userCtx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)}, // Explicitly include banned nodes. + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + + // Assert that nodes are returned when IncludeBanned is true. + searchResponse = res.(drip.SearchNodes200JSONResponse) + require.NotEmpty(t, searchResponse.Nodes, "Search should include banned nodes") + + // Step 3: Assert that the banned node is included in the search result. + foundBannedNode := false + for _, node := range *searchResponse.Nodes { + if *node.Id == nodeId { + foundBannedNode = true + break + } + } + require.True(t, foundBannedNode, "The banned node should be present in the search results") + }) + }) +} diff --git a/integration-tests/publisher_ban_test.go b/integration-tests/publisher_ban_test.go new file mode 100644 index 0000000..55caeeb --- /dev/null +++ b/integration-tests/publisher_ban_test.go @@ -0,0 +1,80 @@ +package integration + +import ( + "context" + "net/http" + "registry-backend/config" + "registry-backend/drip" + "registry-backend/ent/schema" + authorization "registry-backend/server/middleware/authorization" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPublisherBan(t *testing.T) { + clientCtx := context.Background() + client, cleanup := setupDB(t, clientCtx) + defer cleanup() + + // Setup the mock services and server + impl := NewStrictServerImplementationWithMocks(client, &config.Config{}) + authz := authorization.NewAuthorizationManager(client, impl.RegistryService).AuthorizationMiddleware() + + t.Run("Publisher Ban Tests", func(t *testing.T) { + userCtx, _ := setupTestUser(client) + adminCtx, _ := setupAdminUser(client) + + // Setup a test publisher + publisherId, err := setupPublisher(userCtx, authz, impl) + require.NoError(t, err, "should set up publisher") + + // endpoints to test the authorization middleware + testEndpoints := []struct { + name string + invoke func(ctx context.Context) error + }{ + {"CreatePublisher", func(ctx context.Context) error { + _, err := setupPublisher(ctx, authz, impl) + return err + }}, + {"DeleteNodeVersion", func(ctx context.Context) error { + _, err := withMiddleware(authz, impl.DeleteNodeVersion)( + ctx, drip.DeleteNodeVersionRequestObject{ + PublisherId: publisherId, + }) + return err + }}, + } + + t.Run("Ban publisher by non-admin", func(t *testing.T) { + // Use the same publisher and node created earlier + res, err := withMiddleware(authz, impl.BanPublisher)( + userCtx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err, "should not ban publisher by non-admin") + require.IsType(t, drip.BanPublisher403JSONResponse{}, res) + }) + + t.Run("Ban publisher by admin", func(t *testing.T) { + // Use the same publisher and node created earlier + _, err := withMiddleware(authz, impl.BanPublisher)( + adminCtx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err) + + pub, err := client.Publisher.Get(userCtx, publisherId) + require.NoError(t, err) + assert.Equal(t, schema.PublisherStatusTypeBanned, pub.Status, "should ban publisher") + }) + + t.Run("Calling endpoints with a banned user", func(t *testing.T) { + for _, tc := range testEndpoints { + t.Run(tc.name, func(t *testing.T) { + err := tc.invoke(userCtx) + require.Error(t, err, "should return error") + assertHTTPError(t, err, http.StatusForbidden) + }) + } + }) + }) +} diff --git a/integration-tests/registry_integration_test.go b/integration-tests/registry_integration_test.go index 9764fa2..fdb70a7 100644 --- a/integration-tests/registry_integration_test.go +++ b/integration-tests/registry_integration_test.go @@ -23,85 +23,12 @@ import ( "github.com/labstack/echo/v4" strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" ) -func setUpTest(client *ent.Client) (context.Context, *ent.User) { - ctx := context.Background() - // create a User and attach to context - testUser := createTestUser(ctx, client) - ctx = decorateUserInContext(ctx, testUser) - return ctx, testUser -} - -func setUpAdminTest(client *ent.Client) (context.Context, *ent.User) { - ctx := context.Background() - testUser := createAdminUser(ctx, client) - ctx = decorateUserInContext(ctx, testUser) - return ctx, testUser -} - -func randomPublisher() *drip.Publisher { - suffix := uuid.New().String() - publisherId := "test-publisher-" + suffix - description := "test-description" + suffix - source_code_repo := "test-source-code-repo" + suffix - website := "test-website" + suffix - support := "test-support" + suffix - logo := "test-logo" + suffix - name := "test-name" + suffix - - return &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - } -} - -func randomNode() *drip.Node { - suffix := uuid.New().String() - nodeId := "test-node" + suffix - nodeDescription := "test-node-description" + suffix - nodeAuthor := "test-node-author" + suffix - nodeLicense := "test-node-license" + suffix - nodeName := "test-node-name" + suffix - nodeTags := []string{"test-node-tag"} - icon := "https://wwww.github.com/test-icon-" + suffix + ".svg" - githubUrl := "https://www.github.com/test-github-url-" + suffix - - return &drip.Node{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Tags: &nodeTags, - Icon: &icon, - Repository: &githubUrl, - } -} - -func randomNodeVersion(revision int) *drip.NodeVersion { - suffix := uuid.New().String() - - version := fmt.Sprintf("1.0.%d", revision) - changelog := "test-changelog-" + suffix - dependencies := []string{"test-dependency" + suffix} - return &drip.NodeVersion{ - Version: &version, - Changelog: &changelog, - Dependencies: &dependencies, - } -} - type mockedImpl struct { *implementation.DripStrictServerImplementation @@ -156,7 +83,7 @@ func TestRegistryPublisher(t *testing.T) { defer cleanup() impl, authz := newMockedImpl(client, &config.Config{}) - ctx, testUser := setUpTest(client) + ctx, testUser := setupTestUser(client) pub := randomPublisher() createPublisherResponse, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ @@ -279,7 +206,7 @@ func TestRegistryPersonalAccessToken(t *testing.T) { defer cleanup() impl, authz := newMockedImpl(client, &config.Config{}) - ctx, _ := setUpTest(client) + ctx, _ := setupTestUser(client) pub := randomPublisher() _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ Body: pub, @@ -320,7 +247,7 @@ func TestRegistryNode(t *testing.T) { defer cleanup() impl, authz := newMockedImpl(client, &config.Config{}) - ctx, _ := setUpTest(client) + ctx, _ := setupTestUser(client) pub := randomPublisher() _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ @@ -478,7 +405,7 @@ func TestRegistryNodeVersion(t *testing.T) { defer cleanup() impl, authz := newMockedImpl(client, &config.Config{}) - ctx, _ := setUpTest(client) + ctx, _ := setupTestUser(client) pub := randomPublisher() respub, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ @@ -527,7 +454,7 @@ func TestRegistryNodeVersion(t *testing.T) { createdNodeVersion = *res.(drip.PublishNodeVersion201JSONResponse).NodeVersion // Needed for downstream tests. t.Run("Admin Update", func(t *testing.T) { - adminCtx, _ := setUpAdminTest(client) + adminCtx, _ := setupAdminUser(client) activeStatus := drip.NodeVersionStatusActive adminUpdateNodeVersionResp, err := impl.AdminUpdateNodeVersion(adminCtx, drip.AdminUpdateNodeVersionRequestObject{ NodeId: *node.Id, @@ -860,7 +787,7 @@ func TestRegistryComfyNode(t *testing.T) { defer cleanup() impl, authz := newMockedImpl(client, &config.Config{}) - ctx, _ := setUpTest(client) + ctx, _ := setupTestUser(client) ctx = drip_logging.SetupLogger().WithContext(ctx) pub := randomPublisher() diff --git a/integration-tests/test_util.go b/integration-tests/test_util.go index 5dad3de..0230023 100644 --- a/integration-tests/test_util.go +++ b/integration-tests/test_util.go @@ -3,11 +3,19 @@ package integration import ( "context" "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" "net" "net/http" "net/http/httptest" "reflect" + "regexp" + "registry-backend/config" "registry-backend/drip" + "registry-backend/mock/gateways" + "registry-backend/server/implementation" auth "registry-backend/server/middleware/authentication" "runtime" "strings" @@ -27,6 +35,207 @@ import ( _ "github.com/lib/pq" ) +// NewStrictServerImplementationWithMocks initializes and returns the implementation with mock services. +func NewStrictServerImplementationWithMocks( + client *ent.Client, config *config.Config) *implementation.DripStrictServerImplementation { + // Mock services setup + mockStorageService := new(gateways.MockStorageService) + mockPubsubService := new(gateways.MockPubSubService) + mockSlackService := new(gateways.MockSlackService) + mockDiscordService := new(gateways.MockDiscordService) + mockAlgolia := new(gateways.MockAlgoliaService) + + // Mock service expectations + mockSlackService.On("SendRegistryMessageToSlack", mock.Anything).Return(nil) + mockAlgolia.On("IndexNodes", mock.Anything, mock.Anything).Return(nil) + + // Return the new implementation with mocked services + return implementation.NewStrictServerImplementation( + client, + config, + mockStorageService, + mockPubsubService, + mockSlackService, + mockDiscordService, + mockAlgolia, + ) +} + +func setupTestUser(client *ent.Client) (context.Context, *ent.User) { + // Create a new context and a test user + ctx := context.Background() + testUser := createTestUser(ctx, client) + + // Attach the test user to the context + ctx = decorateUserInContext(ctx, testUser) + + // Return the context and the created test user + return ctx, testUser +} + +func setupAdminUser(client *ent.Client) (context.Context, *ent.User) { + // Create a new context to isolate the test setup + ctx := context.Background() + + // Attempt to create the admin user + testUser := createAdminUser(ctx, client) + + // Decorate the user in the context + ctx = decorateUserInContext(ctx, testUser) + + // Return the decorated context and the created user + return ctx, testUser +} + +// Helper function to set up a personal access token +func setupPersonalAccessToken( + ctx context.Context, + authz drip.StrictMiddlewareFunc, + impl *implementation.DripStrictServerImplementation, + publisherId string) (*string, error) { + + tokenName := "test-token" + tokenDescription := "test-description" + res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ + PublisherId: publisherId, + Body: &drip.PersonalAccessToken{ + Name: &tokenName, + Description: &tokenDescription, + }, + }) + if err != nil { + return nil, err + } + + // Extract the created token from the response + pat := res.(drip.CreatePersonalAccessToken201JSONResponse).Token + return pat, nil +} + +// Helper function to generate a valid publisher ID based on the pattern "^[a-z][a-z0-9-]*$" +func generatePublisherId() string { + // Generate a random UUID and use a portion of it for the publisher ID + rawId := uuid.New().String() + // Strip hyphens and convert to lowercase to fit the pattern + id := strings.ToLower(strings.ReplaceAll(rawId, "-", "")) + // Ensure the ID starts with a letter and follows the pattern + if match, _ := regexp.MatchString("^[a-z][a-z0-9-]*$", id); match { + return id + } + // If it doesn't match, regenerate a valid ID + return generatePublisherId() +} + +// Helper function to generate a valid node ID based on the pattern "^[a-z][a-z0-9-_]+(\\.[a-z0-9-_]+)*$" +func generateNodeId() string { + // Generate a random UUID and use a portion of it for the node ID + rawId := uuid.New().String() + // Strip hyphens and convert to lowercase to fit the pattern + id := strings.ToLower(strings.ReplaceAll(rawId, "-", "")) + // Ensure the ID starts with a letter and follows the pattern + if match, _ := regexp.MatchString("^[a-z][a-z0-9-_]+(\\.[a-z0-9-_]+)*$", id); match { + return id + } + // If it doesn't match, regenerate a valid ID + return generateNodeId() +} + +// Helper function to generate a random Publisher for testing +func randomPublisher() *drip.Publisher { + suffix := uuid.New().String() + publisherId := generatePublisherId() + + description := "test-description-" + suffix + sourceCodeRepo := "https://github.com/test-repo-" + suffix + website := "https://test-website-" + suffix + ".com" + support := "test-support-" + suffix + logo := "https://test-logo-" + suffix + ".png" + name := "test-name-" + suffix + + return &drip.Publisher{ + Id: &publisherId, + Name: &name, + Description: &description, + SourceCodeRepo: &sourceCodeRepo, + Website: &website, + Support: &support, + Logo: &logo, + } +} + +// Helper function to generate a random Node for testing +func randomNode() *drip.Node { + suffix := uuid.New().String() + nodeId := generateNodeId() + + description := "test-node-description-" + suffix + author := "test-node-author-" + suffix + license := "test-node-license-" + suffix + name := "test-node-name-" + suffix + tags := []string{"test-node-tag"} + icon := "https://www.github.com/test-icon-" + suffix + ".svg" + repository := "https://www.github.com/test-repo-" + suffix + + return &drip.Node{ + Id: &nodeId, + Name: &name, + Description: &description, + Author: &author, + License: &license, + Tags: &tags, + Icon: &icon, + Repository: &repository, + } +} + +// Helper function to generate a random NodeVersion for testing +func randomNodeVersion(revision int) *drip.NodeVersion { + suffix := uuid.New().String() + + version := fmt.Sprintf("1.0.%d", revision) + changelog := "test-changelog-" + suffix + dependencies := []string{"test-dependency-" + suffix} + + return &drip.NodeVersion{ + Version: &version, + Changelog: &changelog, + Dependencies: &dependencies, + } +} + +// Helper function to set up a publisher with a random ID +func setupPublisher( + ctx context.Context, + authz drip.StrictMiddlewareFunc, + impl *implementation.DripStrictServerImplementation) (string, error) { + + publisher := randomPublisher() + + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: publisher, + }) + return *publisher.Id, err +} + +// Helper function to set up a node with a random ID +func setupNode( + ctx context.Context, + authz drip.StrictMiddlewareFunc, + impl *implementation.DripStrictServerImplementation, publisherId string) (string, error) { + + node := randomNode() + node.Id = proto.String(generateNodeId()) + node.Publisher = &drip.Publisher{ + Id: proto.String(publisherId), + } + + _, err := withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: publisherId, + Body: node, + }) + return *node.Id, err +} + func createTestUser(ctx context.Context, client *ent.Client) *ent.User { return client.User.Create(). SetID(uuid.New().String()). @@ -133,24 +342,41 @@ func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration) } func withMiddleware[R any, S any](mw drip.StrictMiddlewareFunc, h func(ctx context.Context, req R) (res S, err error)) func(ctx context.Context, req R) (res S, err error) { + // Adapt the provided handler `h` to the signature expected by the middleware. handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + // Convert the `echo.Context` to a standard `context.Context` and cast the request to the expected type. return h(ctx.Request().Context(), request.(R)) } - nameA := strings.Split(runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), ".") - nameA = strings.Split(nameA[len(nameA)-1], "-") - opname := nameA[0] + // Use reflection to extract the operation name (function name) of the handler for logging or debugging purposes. + nameParts := strings.Split(runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), ".") + nameParts = strings.Split(nameParts[len(nameParts)-1], "-") + opname := nameParts[0] // Isolate the operation name. return func(ctx context.Context, req R) (res S, err error) { + // Create a simulated echo.Context with fake HTTP request and response. fakeReq := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) fakeRes := httptest.NewRecorder() fakeCtx := echo.New().NewContext(fakeReq, fakeRes) - f := mw(handler, opname) - r, err := f(fakeCtx, req) - if r == nil { + // Wrap the adapted handler with the middleware. + wrappedHandler := mw(handler, opname) + + // Invoke the middleware-wrapped handler and cast the result to the expected response type. + result, err := wrappedHandler(fakeCtx, req) + if result == nil { + // Return a zero-value of type S if the result is nil. return *new(S), err } - return r.(S), err + + // Type assert the result to the expected response type S and return. + return result.(S), err } } + +// Helper function for checking error type and code +func assertHTTPError(t *testing.T, err error, expectedCode int) { + require.IsType(t, &echo.HTTPError{}, err) + echoErr := err.(*echo.HTTPError) + assert.Equal(t, expectedCode, echoErr.Code, "should return correct HTTP error code") +}