diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index f8628605..56219610 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -343,7 +343,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client // authCodeURLOptions contains options for AuthCodeURL type authCodeURLOptions struct { - claims, loginHint, tenantID, domainHint string + claims, loginHint, tenantID, domainHint, state string } // AuthCodeURLOption is implemented by options for AuthCodeURL @@ -353,7 +353,7 @@ type AuthCodeURLOption interface { // AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in. // -// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID] +// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithState] func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) { o := authCodeURLOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -363,12 +363,36 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, if err != nil { return "", err } + ap.State = o.state ap.Claims = o.claims ap.LoginHint = o.loginHint ap.DomainHint = o.domainHint return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap) } +// WithState adds a user-generated state to the request. +func WithState(state string) interface { + AuthCodeURLOption + options.CallOption +} { + return struct { + AuthCodeURLOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *authCodeURLOptions: + t.state = state + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // WithLoginHint pre-populates the login prompt with a username. func WithLoginHint(username string) interface { AuthCodeURLOption diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 28bad83e..1d977c54 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -1256,6 +1256,46 @@ func TestWithLoginHint(t *testing.T) { } } +func TestWithState(t *testing.T) { + state := "abc-123-secure-string" + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) + if err != nil { + t.Fatal(err) + } + client.base.Token.Resolver = &fake.ResolveEndpoints{} + for _, expectState := range []bool{true, false} { + t.Run(fmt.Sprint(expectState), func(t *testing.T) { + opts := []AuthCodeURLOption{} + if expectState { + opts = append(opts, WithState(state)) + } + u, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...) + if err != nil { + t.Fatal(err) + } + parsed, err := url.Parse(u) + if err != nil { + t.Fatal(err) + } + if !parsed.Query().Has("state") { + if !expectState { + return + } + t.Fatal("expected a state") + } else if !expectState { + t.Fatal("expected no state") + } + if actual := parsed.Query()["state"]; len(actual) != 1 || actual[0] != state { + t.Fatalf(`unexpected state "%v"`, actual) + } + }) + } +} + func TestWithDomainHint(t *testing.T) { domain := "contoso.com" cred, err := NewCredFromSecret(fakeSecret)