Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into auth-request-interface
Browse files Browse the repository at this point in the history
  • Loading branch information
stebenz committed Dec 18, 2024
2 parents 5043c73 + 95c785b commit 6468498
Show file tree
Hide file tree
Showing 19 changed files with 301 additions and 412 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-22.04
strategy:
matrix:
go: ['1.19', '1.20', '1.21']
go: ['1.19', '1.20', '1.21', '1.22', '1.23']
name: Go ${{ matrix.go }} test
steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ Versions that also build are marked with :warning:.
|---------|--------------------|
| <1.19 | :x: |
| 1.19 | :warning: |
| 1.20 | :white_check_mark: |
| 1.21 | :white_check_mark: |
| 1.20 | :warning: |
| 1.21 | :warning: |
| 1.22 | :white_check_mark: |
| 1.23 | :white_check_mark: |

## Why another library

Expand Down
8 changes: 6 additions & 2 deletions pkg/provider/attribute_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat)
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.TimeFormat, p.Expiration)
return nil
},
func() {
Expand All @@ -139,7 +139,11 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
// create enveloped signature
checkerInstance.WithLogicStep(
func() error {
return createPostSignature(r.Context(), response, p)
cert, key, err := getResponseCert(r.Context(), p.storage)
if err != nil {
return err
}
return createPostSignature(response, key, cert, p.conf.SignatureAlgorithm)
},
func() {
http.Error(w, fmt.Errorf("failed to sign response: %w", err).Error(), http.StatusInternalServerError)
Expand Down
41 changes: 25 additions & 16 deletions pkg/provider/identityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type MetadataIDPConfig struct {
type IdentityProviderConfig struct {
MetadataIDPConfig *MetadataIDPConfig

PostTemplate *template.Template
LogoutTemplate *template.Template

SignatureAlgorithm string
DigestAlgorithm string
EncryptionAlgorithm string
Expand Down Expand Up @@ -68,7 +71,8 @@ type IdentityProvider struct {
metadataEndpoint *Endpoint
endpoints *Endpoints

timeFormat string
TimeFormat string
Expiration time.Duration
}

type Endpoints struct {
Expand All @@ -79,25 +83,30 @@ type Endpoints struct {
attributeEndpoint Endpoint
}

func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storage IDPStorage) (*IdentityProvider, error) {
postTemplate, err := template.New("post").Parse(postTemplate)
if err != nil {
return nil, err
}

logoutTemplate, err := template.New("logout").Parse(logoutTemplate)
if err != nil {
return nil, err
}

func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storage IDPStorage) (_ *IdentityProvider, err error) {
idp := &IdentityProvider{
storage: storage,
metadataEndpoint: &metadata,
conf: conf,
postTemplate: postTemplate,
logoutTemplate: logoutTemplate,
postTemplate: conf.PostTemplate,
logoutTemplate: conf.LogoutTemplate,
endpoints: endpointConfigToEndpoints(conf.Endpoints),
timeFormat: DefaultTimeFormat,
TimeFormat: DefaultTimeFormat,
Expiration: DefaultExpiration,
}

if conf.PostTemplate == nil {
idp.postTemplate, err = template.New("post").Parse(postTemplate)
if err != nil {
return nil, err
}
}

if conf.LogoutTemplate == nil {
idp.logoutTemplate, err = template.New("logout").Parse(logoutTemplate)
if err != nil {
return nil, err
}
}

if conf.MetadataIDPConfig == nil {
Expand Down Expand Up @@ -153,7 +162,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto
return nil, nil, err
}

metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat)
metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.TimeFormat)
return metadata, aaMetadata, nil
}

Expand Down
59 changes: 34 additions & 25 deletions pkg/provider/login.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package provider

import (
"context"
"fmt"
"net/http"

"github.com/zitadel/logging"

"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
)

func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Request) {
Expand All @@ -16,7 +20,6 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
Issuer: p.GetEntityID(r.Context()),
}

ctx := r.Context()
if err := r.ParseForm(); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to parse form: %w", err).Error(), http.StatusInternalServerError)
Expand All @@ -34,52 +37,58 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID)
if err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat))
response.sendBackResponse(r, w, p.errorResponse(response, StatusCodeRequestDenied, fmt.Errorf("failed to get request: %w", err).Error()))
return
}
response.RequestID = authRequest.GetAuthRequestID()
response.RelayState = authRequest.GetRelayState()
response.ProtocolBinding = authRequest.GetBindingType()
response.AcsUrl = authRequest.GetAccessConsumerServiceURL()

if !authRequest.Done() {
entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
return
}
response.Audience = entityID

entityID, err := p.storage.GetEntityIDByAppID(r.Context(), authRequest.GetApplicationID())
samlResponse, err := p.loginResponse(r.Context(), authRequest, response)
if err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get entityID: %w", err).Error(), http.StatusInternalServerError)
response.sendBackResponse(r, w, response.makeFailedResponse(err.Error(), "failed to create response", p.TimeFormat))
return
}
response.Audience = entityID

response.sendBackResponse(r, w, samlResponse)
return
}

func (p *IdentityProvider) loginResponse(ctx context.Context, authRequest models.AuthRequestInt, response *Response) (*samlp.ResponseType, error) {
if !authRequest.Done() {
logging.Error(StatusCodeAuthNFailed)
return nil, fmt.Errorf(StatusCodeAuthNFailed)
}

attrs := &Attributes{}
if err := p.storage.SetUserinfoWithUserID(ctx, authRequest.GetApplicationID(), attrs, authRequest.GetUserID(), []int{}); err != nil {
logging.Error(err)
http.Error(w, fmt.Errorf("failed to get userinfo: %w", err).Error(), http.StatusInternalServerError)
return
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat)
cert, key, err := getResponseCert(ctx, p.storage)
if err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeInvalidAttrNameOrValue)
}

switch response.ProtocolBinding {
case PostBinding:
if err := createPostSignature(r.Context(), samlResponse, p); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
case RedirectBinding:
if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
samlResponse := response.makeSuccessfulResponse(attrs, p.TimeFormat, p.Expiration)
if err := createSignature(response, samlResponse, key, cert, p.conf.SignatureAlgorithm); err != nil {
logging.Error(err)
return nil, fmt.Errorf(StatusCodeResponder)
}
return samlResponse, nil
}

response.sendBackResponse(r, w, samlResponse)
return
func (p *IdentityProvider) errorResponse(response *Response, reason string, description string) *samlp.ResponseType {
return response.makeFailedResponse(reason, description, p.TimeFormat)
}
33 changes: 19 additions & 14 deletions pkg/provider/login_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package provider

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/golang/mock/gomock"
Expand All @@ -23,9 +23,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
Done bool
}
type res struct {
code int
err bool
state string
code int
err bool
state string
inflate bool
b64 bool
}
type sp struct {
appID string
Expand Down Expand Up @@ -235,7 +237,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
ID: "test",
AuthRequestID: "test",
Binding: RedirectBinding,
AcsURL: "url",
AcsURL: "https://sp.example.com",
RelayState: "relaystate",
UserID: "userid",
Done: false,
Expand All @@ -247,9 +249,11 @@ func TestSSO_loginHandleFunc(t *testing.T) {
},
},
res{
code: 500,
state: "",
err: false,
code: 302,
state: StatusCodeAuthNFailed,
err: false,
inflate: true,
b64: true,
}},
}

Expand Down Expand Up @@ -297,14 +301,15 @@ func TestSSO_loginHandleFunc(t *testing.T) {
defer func() {
_ = res.Body.Close()
}()
response, err := ioutil.ReadAll(res.Body)
if res.StatusCode != tt.res.code {
t.Errorf("ssoHandleFunc() code got = %v, want %v", res.StatusCode, tt.res)
return
}

// currently only checked for redirect binding
if tt.res.state != "" {
if err := parseForState(string(response), tt.res.state); err != nil {
responseURL, err := url.Parse(res.Header.Get("Location"))
if err != nil {
t.Errorf("error while parsing url")
}

if err := parseForState(tt.res.inflate, tt.res.b64, responseURL.Query().Get("SAMLResponse"), tt.res.state); err != nil {
t.Errorf("ssoHandleFunc() response state not: %v", tt.res.state)
return
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to parse form: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to parse form: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to decode request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -69,10 +69,10 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
checkIfRequestTimeIsStillValid(
func() string { return logoutRequest.IssueInstant },
func() string { return logoutRequest.NotOnOrAfter },
p.timeFormat,
p.TimeFormat,
),
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to validate request: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -83,7 +83,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
response.sendBackLogoutResponse(w, response.makeFailedLogoutResponse(StatusCodeRequestDenied, fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.TimeFormat))
},
)

Expand All @@ -106,7 +106,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque

response.sendBackLogoutResponse(
w,
response.makeSuccessfulLogoutResponse(p.timeFormat),
response.makeSuccessfulLogoutResponse(p.TimeFormat),
)
logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text))
}
Expand Down
Loading

0 comments on commit 6468498

Please sign in to comment.