Skip to content

Commit

Permalink
feat: cors support for gcp and for custom middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
davemooreuws committed Nov 4, 2023
1 parent 5c7ffde commit 131ccd5
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 53 deletions.
2 changes: 1 addition & 1 deletion cloud/aws/deploy/api/apigateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (

"github.com/nitrictech/nitric/cloud/aws/deploy/config"
"github.com/nitrictech/nitric/cloud/aws/deploy/exec"
"github.com/nitrictech/nitric/cloud/common/deploy/cors"
"github.com/nitrictech/nitric/cloud/common/cors"
common "github.com/nitrictech/nitric/cloud/common/deploy/tags"
v1 "github.com/nitrictech/nitric/core/pkg/api/nitric/v1"
)
Expand Down
2 changes: 1 addition & 1 deletion cloud/azure/deploy/api/apimanagement.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"html/template"
"strings"

"github.com/nitrictech/nitric/cloud/common/deploy/cors"
"github.com/nitrictech/nitric/cloud/common/cors"
"github.com/nitrictech/nitric/cloud/common/deploy/resources"

"github.com/getkin/kin-openapi/openapi3"
Expand Down
134 changes: 134 additions & 0 deletions cloud/common/cors/cors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2021 Nitric Technologies Pty Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cors

import (
"encoding/json"
"fmt"
"os"
"strconv"
"strings"

"github.com/imdario/mergo"
"github.com/valyala/fasthttp"

base_http "github.com/nitrictech/nitric/cloud/common/runtime/gateway"
v1 "github.com/nitrictech/nitric/core/pkg/api/nitric/v1"
"github.com/nitrictech/nitric/core/pkg/worker/pool"
)

func GetCorsConfig(vals *v1.ApiCorsDefinition) (*v1.ApiCorsDefinition, error) {
defaultVal := &v1.ApiCorsDefinition{
AllowCredentials: false,
AllowOrigins: []string{"*"},
AllowHeaders: []string{"Content-Type", "Authorization"},
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
ExposeHeaders: []string{},
MaxAge: 300,
}

if err := mergo.Merge(defaultVal, vals, mergo.WithOverride); err != nil {
return nil, err
}

return defaultVal, nil
}

// Used for GCP and Local CORs with fasthttp headers
func GetCorsHeaders(config *v1.ApiCorsDefinition) (*map[string]string, error) {
corsHeaders := map[string]string{}

corsConfig, err := GetCorsConfig(config)
if err != nil {
return nil, err
}

corsHeaders["Access-Control-Allow-Credentials"] = strconv.FormatBool(corsConfig.GetAllowCredentials())

if len(corsConfig.GetAllowOrigins()) > 0 {
corsHeaders["Access-Control-Allow-Origin"] = strings.Join(corsConfig.GetAllowOrigins(), ",")
}

if len(corsConfig.GetAllowMethods()) > 0 {
corsHeaders["Access-Control-Allow-Methods"] = strings.Join(corsConfig.GetAllowMethods(), ",")
}

if len(corsConfig.GetAllowHeaders()) > 0 {
corsHeaders["Access-Control-Allow-Headers"] = strings.Join(corsConfig.GetAllowHeaders(), ",")
}

if len(corsConfig.GetExposeHeaders()) > 0 {
corsHeaders["Access-Control-Expose-Headers"] = strings.Join(corsConfig.GetExposeHeaders(), ",")
}

corsHeaders["Access-Control-Max-Age"] = strconv.FormatInt(int64(corsConfig.GetMaxAge()), 10)

return &corsHeaders, nil
}

func GetEnvKey(name string) string {
return fmt.Sprintf("NITRIC_CORS_%s", strings.ToUpper(name))
}

func CreateCorsMiddleware(cache map[string]map[string]string) base_http.HttpMiddleware {
return func(rc *fasthttp.RequestCtx, wp pool.WorkerPool) bool {
api := string(rc.Request.Header.Peek("x-nitric-api"))
method := string(rc.Request.Header.Method())

if cache[api] != nil {
applyCorsHeaders(rc, cache[api])
return true
}

corsHeaders, err := getCorsHeadersForAPI(api)
if err != nil {
if method == "OPTIONS" {
rc.Response.SetStatusCode(fasthttp.StatusMethodNotAllowed)
}

return true
}

cache[api] = corsHeaders

applyCorsHeaders(rc, corsHeaders)

return true
}
}

func getCorsHeadersForAPI(name string) (map[string]string, error) {
env := os.Getenv(GetEnvKey(name))

if env == "" {
return nil, fmt.Errorf("no cors env var found for api %s", name)
}

headers := make(map[string]string)

// Unmarshal the JSON string into the map
err := json.Unmarshal([]byte(env), &headers)
if err != nil {
return nil, err
}

return headers, nil
}

func applyCorsHeaders(rc *fasthttp.RequestCtx, corsHeaders map[string]string) {
for k, v := range corsHeaders {
rc.Response.Header.Add(k, v)
}
}
37 changes: 0 additions & 37 deletions cloud/common/deploy/cors/cors.go

This file was deleted.

56 changes: 47 additions & 9 deletions cloud/common/runtime/gateway/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/nitrictech/nitric/core/pkg/plugins/gateway"
"github.com/nitrictech/nitric/core/pkg/span"
"github.com/nitrictech/nitric/core/pkg/utils"
"github.com/nitrictech/nitric/core/pkg/worker"
"github.com/nitrictech/nitric/core/pkg/worker/pool"
)

Expand Down Expand Up @@ -84,17 +85,8 @@ func HttpHeadersToMap(rh *fasthttp.RequestHeader) map[string][]string {

func (s *BaseHttpGateway) httpHandler(workerPool pool.WorkerPool) func(ctx *fasthttp.RequestCtx) {
return func(rc *fasthttp.RequestCtx) {
if s.mw != nil {
if !s.mw(rc, workerPool) {
// middleware has indicated that is has processed the request
// so we can exit here
return
}
}

headerMap := HttpHeadersToMap(&rc.Request.Header)

// httpTrigger := triggers.FromHttpRequest(rc)
headers := map[string]*v1.HeaderValue{}
for k, v := range headerMap {
headers[k] = &v1.HeaderValue{Value: v}
Expand Down Expand Up @@ -127,10 +119,51 @@ func (s *BaseHttpGateway) httpHandler(workerPool pool.WorkerPool) func(ctx *fast
Trigger: httpTrigger,
})
if err != nil {
// if no worker and middleware is enabled, allow options requests to hit the middleware
if s.mw != nil && string(rc.Request.Header.Method()) == "OPTIONS" {
// find worker based on path
wrkrs := workerPool.GetWorkers(&pool.GetWorkerOptions{
Filter: func(w worker.Worker) bool {
if api, ok := w.(*worker.RouteWorker); ok {
_, err := api.ExtractPathParams(string(rc.URI().PathOriginal()))

return err == nil
}

return false
},
})

if len(wrkrs) > 0 {
rw, ok := wrkrs[0].(*worker.RouteWorker)
if ok {
s.addApiHeader(rc, rw)
}

if s.mw != nil {
s.mw(rc, workerPool)
return
}
}
}

rc.Error("Unable to get worker to handle request", 500)
return
}

rw, ok := wrkr.(*worker.RouteWorker)
if ok {
s.addApiHeader(rc, rw)
}

if s.mw != nil {
if !s.mw(rc, workerPool) {
// middleware has indicated that is has processed the request
// so we can exit here
return
}
}

response, err := wrkr.HandleTrigger(span.FromHeaders(context.TODO(), headerMap), httpTrigger)
if err != nil {
rc.Error(fmt.Sprintf("Error handling HTTP Request: %v", err), 500)
Expand All @@ -157,6 +190,11 @@ func (s *BaseHttpGateway) httpHandler(workerPool pool.WorkerPool) func(ctx *fast
}
}

func (s *BaseHttpGateway) addApiHeader(rc *fasthttp.RequestCtx, worker *worker.RouteWorker) {
// this header can be used by runtime plugins to determine the api
rc.Request.Header.Add("X-Nitric-Api", worker.Api())
}

func (s *BaseHttpGateway) Start(pool pool.WorkerPool) error {
r := router.New()

Expand Down
43 changes: 43 additions & 0 deletions cloud/gcp/deploy/gateway/apigateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/nitrictech/nitric/cloud/common/deploy/utils"
"github.com/nitrictech/nitric/cloud/gcp/deploy/exec"
"github.com/nitrictech/nitric/cloud/gcp/deploy/iam"
v1 "github.com/nitrictech/nitric/core/pkg/api/nitric/v1"
)

type ApiGatewayArgs struct {
Expand All @@ -44,6 +45,7 @@ type ApiGatewayArgs struct {
OpenAPISpec *openapi3.T
Functions map[string]*exec.CloudRunner
SecuritySchemes openapi3.SecuritySchemes
Cors *v1.ApiCorsDefinition
}

type ApiGateway struct {
Expand Down Expand Up @@ -170,6 +172,12 @@ func NewApiGateway(ctx *pulumi.Context, name string, args *ApiGatewayArgs, opts
p.Put = gcpOperation(p.Put, naps)
p.Delete = gcpOperation(p.Delete, naps)
p.Options = gcpOperation(p.Options, naps)

// if cors and no options defined, set it
if args.Cors != nil && p.Options == nil {
p.Options = createCorsOp(p, naps)
}

v2doc.Paths[k] = p
}

Expand Down Expand Up @@ -315,3 +323,38 @@ func gcpOperation(op *openapi2.Operation, urls map[string]string) *openapi2.Oper

return op
}

func createCorsOp(p *openapi2.PathItem, naps map[string]string) *openapi2.Operation {
var extensions map[string]interface{}
var name string
var responses map[string]*openapi2.Response

// get extensions
if p.Get != nil {
extensions = p.Get.Extensions
responses = p.Get.Responses
name = strings.TrimSuffix(p.Get.OperationID, "get")
} else if p.Post != nil {
extensions = p.Post.Extensions
responses = p.Post.Responses
name = strings.TrimSuffix(p.Get.OperationID, "post")
} else if p.Patch != nil {
extensions = p.Patch.Extensions
responses = p.Patch.Responses
name = strings.TrimSuffix(p.Get.OperationID, "patch")
} else if p.Put != nil {
extensions = p.Get.Extensions
responses = p.Get.Responses
name = strings.TrimSuffix(p.Get.OperationID, "put")
} else if p.Delete != nil {
extensions = p.Delete.Extensions
responses = p.Delete.Responses
name = strings.TrimSuffix(p.Get.OperationID, "delete")
}

return gcpOperation(&openapi2.Operation{
OperationID: fmt.Sprintf("%soptions", name),
Extensions: extensions,
Responses: responses,
}, naps)
}
Loading

0 comments on commit 131ccd5

Please sign in to comment.