Skip to content

Commit

Permalink
Merge pull request #18 from zitadel/storage
Browse files Browse the repository at this point in the history
fix(saml): update dependencies and integration for ZITADEL V2
  • Loading branch information
livio-a authored Aug 11, 2022
2 parents a8d4f83 + cc1a2bd commit 2a3e2f3
Show file tree
Hide file tree
Showing 40 changed files with 1,410 additions and 885 deletions.
10 changes: 1 addition & 9 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,17 @@ require (
github.com/russellhaering/goxmldsig v1.2.0
github.com/stretchr/testify v1.7.1
github.com/zitadel/logging v0.3.3
github.com/zitadel/oidc v1.3.1
gopkg.in/square/go-jose.v2 v2.6.0
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/felixge/httpsnoop v1.0.1 // indirect
github.com/golang/protobuf v1.4.2 // indirect
github.com/gorilla/schema v1.2.0 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/google/go-cmp v0.5.2 // indirect
github.com/jonboulle/clockwork v0.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 // indirect
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 // indirect
golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect
golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
golang.org/x/text v0.3.7 // indirect
google.golang.org/appengine v1.6.6 // indirect
google.golang.org/protobuf v1.25.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)
348 changes: 0 additions & 348 deletions go.sum

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions pkg/http/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package http

import (
"encoding/json"
"net/http"
"reflect"
)

func MarshalJSON(w http.ResponseWriter, i interface{}) {
MarshalJSONWithStatus(w, i, http.StatusOK)
}

func MarshalJSONWithStatus(w http.ResponseWriter, i interface{}, status int) {
w.Header().Set("content-type", "application/json")
w.WriteHeader(status)
if i == nil || (reflect.ValueOf(i).Kind() == reflect.Ptr && reflect.ValueOf(i).IsNil()) {
return
}
err := json.NewEncoder(w).Encode(i)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
30 changes: 16 additions & 14 deletions pkg/provider/attribute_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package provider

import (
"fmt"
"io/ioutil"
"net/http"

"github.com/zitadel/logging"

"github.com/zitadel/saml/pkg/provider/checker"
"github.com/zitadel/saml/pkg/provider/serviceprovider"
"github.com/zitadel/saml/pkg/provider/xml"
Expand All @@ -11,8 +15,6 @@ import (
"github.com/zitadel/saml/pkg/provider/xml/samlp"
"github.com/zitadel/saml/pkg/provider/xml/soap"
"github.com/zitadel/saml/pkg/provider/xml/xml_dsig"
"io/ioutil"
"net/http"
)

func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *http.Request) {
Expand All @@ -23,6 +25,14 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
var attrQuery *samlp.AttributeQueryType
var response *samlp.ResponseType

metadata, _, err := p.GetMetadata(r.Context())
if err != nil {
err := fmt.Errorf("failed to read idp metadata: %w", err)
logging.Error(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

//parse body to string
checkerInstance.WithLogicStep(
func() error {
Expand All @@ -33,7 +43,6 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
attrQueryRequest = string(b)
return nil
},
"SAML-ap2j3n1",
func() {
http.Error(w, fmt.Errorf("failed to parse body: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -48,7 +57,6 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
}
return nil
},
"SAML-qpoin2a",
func() {
http.Error(w, fmt.Errorf("failed to decode request: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -63,7 +71,6 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
}
return nil
},
" SAML-asdi1n",
func() {
http.Error(w, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -79,7 +86,6 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
func() *xml_dsig.SignatureType { return attrQuery.Signature },
func() *md.EntityDescriptorType { return sp.Metadata },
),
"SAML-bxi3n5",
func() {
http.Error(w, fmt.Errorf("failed to validate certificate from request: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -95,16 +101,14 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
func() *serviceprovider.ServiceProvider { return sp },
func(errF error) { err = errF },
),
"SAML-ao1n2ps",
func() {
http.Error(w, fmt.Errorf("failed to extract signature from request: %w", err).Error(), http.StatusInternalServerError)
},
)

// verify that destination in request is this IDP
checkerInstance.WithLogicStep(
func() error { err = p.verifyRequestDestinationOfAttrQuery(attrQuery); return err },
"SAML-ap2n1a",
func() error { err = verifyRequestDestinationOfAttrQuery(metadata, attrQuery); return err },
func() {
http.Error(w, fmt.Errorf("failed to verify request destination: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -124,10 +128,9 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
response = makeAttributeQueryResponse(attrQuery.Id, p.EntityID, sp.GetEntityID(), attrs, queriedAttrs)
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs)
return nil
},
"SAML-wosm22",
func() {
http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -136,9 +139,8 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
// create enveloped signature
checkerInstance.WithLogicStep(
func() error {
return createPostSignature(response, p)
return createPostSignature(r.Context(), response, p)
},
"SAML-p012sa",
func() {
http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError)
},
Expand All @@ -156,7 +158,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
}

if err := xml.WriteXMLMarshalled(w, soapResponse); err != nil {
logging.Log("SAML-91j12bk").Error(err)
logging.Error(err)
http.Error(w, fmt.Errorf("failed to send response: %w", err).Error(), http.StatusInternalServerError)
}
}
28 changes: 14 additions & 14 deletions pkg/provider/checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ func (c *Checker) CheckFailed() bool {
return false
}

func (c *Checker) WithValueNotEmptyCheck(valueName string, value func() string, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithValueNotEmptyCheck(valueName string, value func() string, errorFunc func()) *Checker {
c.addStep(func() bool {
if value() == "" {
logging.Log(errorLogID).Errorf("empty value %s", valueName)
logging.Errorf("empty value %s", valueName)
errorFunc()
return true
}
Expand All @@ -34,11 +34,11 @@ func (c *Checker) WithValueNotEmptyCheck(valueName string, value func() string,
return c
}

func (c *Checker) WithValuesNotEmptyCheck(values func() []string, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithValuesNotEmptyCheck(values func() []string, errorFunc func()) *Checker {
c.addStep(func() bool {
for _, value := range values() {
if value == "" {
logging.Log(errorLogID).Errorf("empty value")
logging.Errorf("empty value")
errorFunc()
return true
}
Expand All @@ -48,10 +48,10 @@ func (c *Checker) WithValuesNotEmptyCheck(values func() []string, errorLogID str
return c
}

func (c *Checker) WithValueLengthCheck(valueName string, value func() string, minlength, maxlength int, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithValueLengthCheck(valueName string, value func() string, minlength, maxlength int, errorFunc func()) *Checker {
c.addStep(func() bool {
if (minlength > 0 && len(value()) < minlength) || (maxlength > 0 && len(value()) > maxlength) {
logging.Log(errorLogID).Errorf("error with value length %s", valueName)
logging.Errorf("error with value length %s", valueName)
errorFunc()
return true
}
Expand All @@ -62,10 +62,10 @@ func (c *Checker) WithValueLengthCheck(valueName string, value func() string, mi
return c
}

func (c *Checker) WithValueEqualsCheck(valueName string, value func() string, equal func() string, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithValueEqualsCheck(valueName string, value func() string, equal func() string, errorFunc func()) *Checker {
c.addStep(func() bool {
if value() != equal() {
logging.Log(errorLogID).Errorf("value not equal %s: %s, %s", valueName, value(), equal())
logging.Errorf("value not equal %s: %s, %s", valueName, value(), equal())
errorFunc()
return true
}
Expand All @@ -76,11 +76,11 @@ func (c *Checker) WithValueEqualsCheck(valueName string, value func() string, eq
return c
}

func (c *Checker) WithConditionalValueNotEmpty(cond func() bool, valueName string, value func() string, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithConditionalValueNotEmpty(cond func() bool, valueName string, value func() string, errorFunc func()) *Checker {
c.addStep(func() bool {
if cond() {
if value() == "" {
logging.Log(errorLogID).Errorf("empty value %s", valueName)
logging.Errorf("empty value %s", valueName)
errorFunc()
return true
}
Expand All @@ -91,11 +91,11 @@ func (c *Checker) WithConditionalValueNotEmpty(cond func() bool, valueName strin
return c
}

func (c *Checker) WithConditionalLogicStep(cond func() bool, logic func() error, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithConditionalLogicStep(cond func() bool, logic func() error, errorFunc func()) *Checker {
c.addStep(func() bool {
if cond() {
if err := logic(); err != nil {
logging.Log(errorLogID).Error(err)
logging.Error(err)
errorFunc()
return true
}
Expand All @@ -106,10 +106,10 @@ func (c *Checker) WithConditionalLogicStep(cond func() bool, logic func() error,
return c
}

func (c *Checker) WithLogicStep(logic func() error, errorLogID string, errorFunc func()) *Checker {
func (c *Checker) WithLogicStep(logic func() error, errorFunc func()) *Checker {
c.addStep(func() bool {
if err := logic(); err != nil {
logging.Log(errorLogID).Error(err)
logging.Error(err)
errorFunc()
return true
}
Expand Down
11 changes: 3 additions & 8 deletions pkg/provider/checker/checker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package checker_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"github.com/zitadel/saml/pkg/provider/checker"
"testing"
)

func Test_CheckerEmpty(t *testing.T) {
Expand Down Expand Up @@ -48,7 +50,6 @@ func Test_CheckerWithValueNotEmptyCheck(t *testing.T) {
checkerInstance.WithValueNotEmptyCheck(
"test",
func() string { return tt.arg },
"test",
errorFunc,
)

Expand Down Expand Up @@ -108,7 +109,6 @@ func Test_CheckerWithValuesNotEmptyCheck(t *testing.T) {
checkerInstance := &checker.Checker{}
checkerInstance.WithValuesNotEmptyCheck(
func() []string { return tt.arg },
"test",
errorFunc,
)

Expand Down Expand Up @@ -206,7 +206,6 @@ func Test_CheckerWithValueLengthCheck(t *testing.T) {
func() string { return tt.args.value },
tt.args.minlength,
tt.args.maxlength,
"test",
errorFunc,
)

Expand Down Expand Up @@ -267,7 +266,6 @@ func Test_CheckerWithValueEqualsCheck(t *testing.T) {
"test",
func() string { return tt.args.value },
func() string { return tt.args.equals },
"test",
errorFunc,
)

Expand Down Expand Up @@ -364,7 +362,6 @@ func Test_CheckerWithConditionalValueNotEmpty(t *testing.T) {
func() bool { return tt.args.condition },
"test",
func() string { return tt.args.value },
"test",
errorFunc,
)

Expand Down Expand Up @@ -400,7 +397,6 @@ func Test_CheckerWithLogicStep(t *testing.T) {
checkerInstance := &checker.Checker{}
checkerInstance.WithLogicStep(
tt.arg,
"test",
errorFunc,
)

Expand Down Expand Up @@ -463,7 +459,6 @@ func Test_CheckerWithConditionalLogicStep(t *testing.T) {
checkerInstance.WithConditionalLogicStep(
tt.args.cond,
tt.args.f,
"test",
errorFunc,
)

Expand Down
Loading

0 comments on commit 2a3e2f3

Please sign in to comment.