Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add authentication for services calling endpoints #63

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions server/middleware/authentication/service_account_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package drip_authentication

import (
"net/http"
"strings"

"github.com/labstack/echo/v4"
"github.com/rs/zerolog/log"
"google.golang.org/api/idtoken"
)

func ServiceAccountAuthMiddleware() echo.MiddlewareFunc {
// Handlers in here should be checked by this middleware.
var checklist = map[string][]string{
"/security-scan": {"GET"},
"/nodes/reindex": {"POST"},
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {
// Check if the request path and method are in the checklist
path := ctx.Request().URL.Path
method := ctx.Request().Method

methods, ok := checklist[path]
if !ok {
return next(ctx)
}

for _, m := range methods {
if method == m {
ok = true
break
}
}
if !ok {
return next(ctx)
}

// validate token
authHeader := ctx.Request().Header.Get("Authorization")
token := ""
if strings.HasPrefix(authHeader, "Bearer ") {
token = authHeader[7:] // Skip the "Bearer " part
}

if token == "" {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing token")
}

log.Ctx(ctx.Request().Context()).Info().Msgf("Validating google id token %s for path %s and method %s", token, path, method)

payload, err := idtoken.Validate(ctx.Request().Context(), token, "https://api.comfy.org")
if err != nil {
log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid token")
return ctx.JSON(http.StatusUnauthorized, "Invalid token")
}

email, _ := payload.Claims["email"].(string)
if email != "[email protected]" {
log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid email")
return ctx.JSON(http.StatusUnauthorized, "Invalid email")
}

log.Ctx(ctx.Request().Context()).Info().Msgf("Service Account Email: %s", email)
return next(ctx)
}
}
}
67 changes: 67 additions & 0 deletions server/middleware/authentication/service_account_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package drip_authentication

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
)

func TestServiceAccountAllowList(t *testing.T) {
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

middleware := ServiceAccountAuthMiddleware()

tests := []struct {
name string
path string
method string
allowed bool
}{
{"OpenAPI GET", "/openapi", "GET", true},
{"Session DELETE", "/users/sessions", "DELETE", true},
{"Health GET", "/health", "GET", true},
{"VM ANY", "/vm", "POST", true},
{"VM ANY GET", "/vm", "GET", true},
{"Artifact POST", "/upload-artifact", "POST", true},
{"Git Commit POST", "/gitcommit", "POST", true},
{"Git Commit GET", "/gitcommit", "GET", true},
{"Branch GET", "/branch", "GET", true},
{"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true},
{"Publisher POST", "/publishers", "POST", true},
{"Unauthorized Path", "/nonexistent", "GET", true},
{"Get All Nodes", "/nodes", "GET", true},
{"Install Nodes", "/nodes/node-id/install", "GET", true},

{"Reindex Nodes", "/nodes/reindex", "POST", false},
{"Reindex Nodes", "/security-scan", "GET", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
c.SetRequest(req)
handled := false
next := echo.HandlerFunc(func(c echo.Context) error {
handled = true
return nil
})
err := middleware(next)(c)
if tt.allowed {
assert.True(t, handled, "Request should be allowed through")
assert.Nil(t, err)
} else {
assert.False(t, handled, "Request should not be allowed through")
assert.NotNil(t, err)
httpError, ok := err.(*echo.HTTPError)
assert.True(t, ok, "Error should be HTTPError")
assert.Equal(t, http.StatusUnauthorized, httpError.Code)
}
})
}
}
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (s *Server) Start() error {
}

slackService := gateway.NewSlackService(s.Config)
algoliaService, err := algolia.NewFromEnvOrNoop()
algoliaService, err := algolia.NewFromEnv()
if err != nil {
return err
}
Expand Down Expand Up @@ -120,6 +120,7 @@ func (s *Server) Start() error {
// Global Middlewares
e.Use(drip_metric.MetricsMiddleware(mon, s.Config))
e.Use(drip_authentication.FirebaseAuthMiddleware(s.Client))
e.Use(drip_authentication.ServiceAccountAuthMiddleware())
e.Use(drip_authentication.JWTAdminAuthMiddleware(s.Client, s.Config.JWTSecret))
e.Use(drip_middleware.ErrorLoggingMiddleware())

Expand Down
Loading