From 2021ef5da8f79dd3544e65fad432dca03023ca41 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 9 Jan 2025 09:38:18 +0000 Subject: [PATCH 01/29] Added support for force refresh_in --- apps/confidential/confidential_test.go | 56 ++++++++++++++++++ apps/internal/base/base.go | 6 +- apps/internal/base/internal/storage/items.go | 4 +- .../base/internal/storage/items_test.go | 3 + .../internal/storage/partitioned_storage.go | 1 + .../storage/partitioned_storage_test.go | 2 + .../internal/base/internal/storage/storage.go | 1 + .../base/internal/storage/storage_test.go | 18 ++++-- apps/internal/mock/mock.go | 17 ++++++ .../ops/accesstokens/accesstokens_test.go | 59 +++++++++++++++++++ .../internal/oauth/ops/accesstokens/tokens.go | 27 +++++++++ 11 files changed, 186 insertions(+), 8 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 1036c4f7..86cde3b2 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -812,6 +812,62 @@ func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.Rep return nil } +func TestAcquireTokenSilentRefreshIn(t *testing.T) { + + for _, test := range []struct { + expireOn int + refreshIn int + }{ + {3600, 1}, + {7200, 3600}, + } { + cache := make(testCache) + accessToken := "*" + lmo := "login.microsoftonline.com" + tenantA, tenantB := "a", "b" + authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", test.expireOn, test.refreshIn))) + + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test + ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + account := ar.Account + if actual := account.Realm; actual != tenantA { + t.Fatalf(`unexpected realm "%s"`, actual) + } + + // a client configured for a different tenant should be able to authenticate silently with the shared cache's data + client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + // this should succeed because the cache contains an access token from tenantA + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) + if err != nil && test.refreshIn > 1 { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } + } +} + func TestWithCache(t *testing.T) { cache := make(testCache) accessToken := "*" diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index e473d126..b32d4c99 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -86,6 +86,7 @@ type AuthResult struct { Account shared.Account IDToken accesstokens.IDToken AccessToken string + RefreshIn time.Time ExpiresOn time.Time GrantedScopes []string DeclinedScopes []string @@ -128,6 +129,7 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu Account: account, IDToken: idToken, AccessToken: accessToken, + RefreshIn: storageTokenResponse.AccessToken.RefreshIn.T, ExpiresOn: storageTokenResponse.AccessToken.ExpiresOn.T, GrantedScopes: grantedScopes, DeclinedScopes: nil, @@ -346,7 +348,9 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - return ar, err + if ar.RefreshIn.After(time.Now()) { + return ar, err + } } } diff --git a/apps/internal/base/internal/storage/items.go b/apps/internal/base/internal/storage/items.go index f9be9027..33b62c6b 100644 --- a/apps/internal/base/internal/storage/items.go +++ b/apps/internal/base/internal/storage/items.go @@ -72,6 +72,7 @@ type AccessToken struct { ClientID string `json:"client_id,omitempty"` Secret string `json:"secret,omitempty"` Scopes string `json:"target,omitempty"` + RefreshIn internalTime.Unix `json:"refresh_in,omitempty"` ExpiresOn internalTime.Unix `json:"expires_on,omitempty"` ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"` CachedAt internalTime.Unix `json:"cached_at,omitempty"` @@ -83,7 +84,7 @@ type AccessToken struct { } // NewAccessToken is the constructor for AccessToken. -func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken { +func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshIn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken { return AccessToken{ HomeAccountID: homeID, Environment: env, @@ -93,6 +94,7 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex Secret: token, Scopes: scopes, CachedAt: internalTime.Unix{T: cachedAt.UTC()}, + RefreshIn: internalTime.Unix{T: refreshIn.UTC()}, ExpiresOn: internalTime.Unix{T: expiresOn.UTC()}, ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()}, TokenType: tokenType, diff --git a/apps/internal/base/internal/storage/items_test.go b/apps/internal/base/internal/storage/items_test.go index d1df933d..a5373452 100644 --- a/apps/internal/base/internal/storage/items_test.go +++ b/apps/internal/base/internal/storage/items_test.go @@ -56,7 +56,9 @@ var ( ) func TestCreateAccessToken(t *testing.T) { + testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) + testRefreshIn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC) actualAt := NewAccessToken("testHID", @@ -64,6 +66,7 @@ func TestCreateAccessToken(t *testing.T) { "realm", "clientID", testCachedAt, + testRefreshIn, testExpiresOn, testExtExpiresOn, "user.read", diff --git a/apps/internal/base/internal/storage/partitioned_storage.go b/apps/internal/base/internal/storage/partitioned_storage.go index c0931833..1ffc928e 100644 --- a/apps/internal/base/internal/storage/partitioned_storage.go +++ b/apps/internal/base/internal/storage/partitioned_storage.go @@ -114,6 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes realm, clientID, cachedAt, + tokenResponse.RefreshIn.T, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, diff --git a/apps/internal/base/internal/storage/partitioned_storage_test.go b/apps/internal/base/internal/storage/partitioned_storage_test.go index 86859cf2..e90d94c5 100644 --- a/apps/internal/base/internal/storage/partitioned_storage_test.go +++ b/apps/internal/base/internal/storage/partitioned_storage_test.go @@ -155,6 +155,7 @@ func TestReadPartitionedAccessToken(t *testing.T) { now, now, now, + now, "openid user.read", "secret", "Bearer", @@ -211,6 +212,7 @@ func TestWritePartitionedAccessToken(t *testing.T) { now, now, now, + now, "openid", "secret", "tokenType", diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/internal/storage/storage.go index 2221e60c..bc34f433 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/internal/storage/storage.go @@ -193,6 +193,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces realm, clientID, cachedAt, + tokenResponse.RefreshIn.T, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index 0570115c..0cc8218b 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -227,6 +227,7 @@ func TestReadAccessToken(t *testing.T) { now, now, now, + now, "openid user.read", "secret", "tokenType", @@ -241,6 +242,7 @@ func TestReadAccessToken(t *testing.T) { now, now, now, + now, "openid user.read", "secret2", "", @@ -343,6 +345,7 @@ func TestWriteAccessToken(t *testing.T) { now, now, now, + now, "openid", "secret", "tokenType", @@ -848,6 +851,7 @@ func TestIsAccessTokenValid(t *testing.T) { cachedAt := time.Now() badCachedAt := time.Now().Add(500 * time.Second) expiresOn := time.Now().Add(1000 * time.Second) + refreshIn := time.Now().Add(1000 * time.Second) badExpiresOn := time.Now().Add(200 * time.Second) extended := time.Now() @@ -858,16 +862,16 @@ func TestIsAccessTokenValid(t *testing.T) { }{ { desc: "Success", - token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, expiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""), }, { desc: "ExpiresOnUnixTimestamp has expired", - token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, badExpiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, badExpiresOn, extended, "openid", "secret", "tokenType", ""), err: true, }, { desc: "Success", - token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, expiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""), err: true, }, } @@ -890,6 +894,7 @@ func TestRead(t *testing.T) { "realm", "cid", time.Now(), + time.Now(), time.Now().Add(1000*time.Second), time.Now(), "openid profile", @@ -1039,6 +1044,7 @@ func TestWrite(t *testing.T) { "realm", "cid", now, + now, now.Add(1000*time.Second), now, "openid profile", @@ -1136,7 +1142,7 @@ func TestRemoveRefreshTokens(t *testing.T) { func TestRemoveAccessTokens(t *testing.T) { now := time.Now() storageManager := newForTest(nil) - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid", "secret", "tokenType", "") key := testAccessToken.Key() contract := &Contract{ AccessTokens: map[string]AccessToken{ @@ -1187,7 +1193,7 @@ func TestRemoveAccountObject(t *testing.T) { func TestRemoveAccount(t *testing.T) { now := time.Now() - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") @@ -1229,7 +1235,7 @@ func TestRemoveAccount(t *testing.T) { func TestRemoveEmptyAccount(t *testing.T) { now := time.Now() - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid profile", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 5de171fd..80c8fe81 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -94,6 +94,23 @@ func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, e return []byte(body) } +func GetAccessTokenBodyWithRefreshIn(accessToken, idToken, refreshToken, clientInfo string, expiresIn int, refreshIn int) []byte { + body := fmt.Sprintf( + `{"access_token": "%s","expires_in": %d,"refresh_in":%d ,"expires_on": %d,"token_type": "Bearer"`, + accessToken, expiresIn, refreshIn, time.Now().Add(time.Duration(expiresIn)*time.Second).Unix(), + ) + if clientInfo != "" { + body += fmt.Sprintf(`, "client_info": "%s"`, clientInfo) + } + if idToken != "" { + body += fmt.Sprintf(`, "id_token": "%s"`, idToken) + } + if refreshToken != "" { + body += fmt.Sprintf(`, "refresh_token": "%s"`, refreshToken) + } + body += "}" + return []byte(body) +} func GetIDToken(tenant, issuer string) string { now := time.Now().Unix() payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, now+3600, now, issuer, tenant)) diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index 59d3506d..7d9018a5 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -778,6 +778,50 @@ func TestTokenResponseUnmarshal(t *testing.T) { }, jwtDecoder: jwtDecoderFake, }, + { + desc: "Success", + payload: ` + { + "access_token": "secret", + "expires_in": 3600, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: internalTime.DurationTime{T: time.Unix(3600, 0)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, + { + desc: "Success", + payload: ` + { + "access_token": "secret", + "expires_in": 36000, + "ext_expires_in": 86399, + "client_info": {"uid": "uid","utid": "utid"}, + "scope": "openid profile" + }`, + want: TokenResponse{ + AccessToken: "secret", + ExpiresOn: internalTime.DurationTime{T: time.Unix(36000, 0)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, + ClientInfo: ClientInfo{ + UID: "uid", + UTID: "utid", + }, + }, + jwtDecoder: jwtDecoderFake, + }, } for _, test := range tests { @@ -795,6 +839,21 @@ func TestTokenResponseUnmarshal(t *testing.T) { case err != nil: continue } + now := time.Now() + timeRemaining := got.ExpiresOn.T.Sub(now) + if got.ExpiresOn.T.Before(time.Now().Add(time.Hour * 2)) { + expectedRefreshIn := now.Add(timeRemaining) + const tolerance = 100 * time.Millisecond + if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance { + t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn) + } + } else { + expectedRefreshIn := now.Add(timeRemaining / 2) + const tolerance = 100 * time.Millisecond + if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance { + t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn) + } + } // Note: IncludeUnexported prevents minor differences in time.Time due to internal fields. if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" { diff --git a/apps/internal/oauth/ops/accesstokens/tokens.go b/apps/internal/oauth/ops/accesstokens/tokens.go index 3107b45c..76c895f7 100644 --- a/apps/internal/oauth/ops/accesstokens/tokens.go +++ b/apps/internal/oauth/ops/accesstokens/tokens.go @@ -173,6 +173,7 @@ type TokenResponse struct { FamilyID string `json:"foci"` IDToken IDToken `json:"id_token"` ClientInfo ClientInfo `json:"client_info"` + RefreshIn internalTime.DurationTime `json:"refresh_in"` ExpiresOn internalTime.DurationTime `json:"expires_in"` ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` GrantedScopes Scopes `json:"scope"` @@ -183,6 +184,32 @@ type TokenResponse struct { scopesComputed bool } +func (t *TokenResponse) UnmarshalJSON(data []byte) error { + type Alias TokenResponse + aux := &struct { + *Alias + }{ + Alias: (*Alias)(t), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // If refresh_in is not set, compute it as (now - expires_on) / 2 if expires_on > 2 hours from now + if t.RefreshIn.T.IsZero() && !t.ExpiresOn.T.IsZero() { + now := time.Now() + timeRemaining := t.ExpiresOn.T.Sub(now) + if timeRemaining > 2*time.Hour { + t.RefreshIn = internalTime.DurationTime{T: now.Add(timeRemaining / 2)} + } else { + t.RefreshIn = internalTime.DurationTime{T: t.ExpiresOn.T} + } + } + + return nil +} + // ComputeScope computes the final scopes based on what was granted by the server and // what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted // This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat From 8d7a65bba803f405575c422780d9e2a95ab82c8f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 9 Jan 2025 10:17:15 +0000 Subject: [PATCH 02/29] fixed failing test --- apps/confidential/confidential_test.go | 2 ++ apps/internal/base/base.go | 2 +- apps/internal/base/internal/storage/storage_test.go | 5 ++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 86cde3b2..7401269d 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -138,6 +138,7 @@ func TestAcquireTokenByCredential(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, + RefreshIn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, @@ -463,6 +464,7 @@ func TestADFSTokenCaching(t *testing.T) { AccessToken: "at1", RefreshToken: "rt", TokenType: "bearer", + RefreshIn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index b32d4c99..9d208e89 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -348,7 +348,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - if ar.RefreshIn.After(time.Now()) { + if ar.RefreshIn.IsZero() || ar.RefreshIn.After(time.Now()) { return ar, err } } diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index 0cc8218b..f8e13334 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -1013,6 +1013,8 @@ func TestWrite(t *testing.T) { PreferredUsername: "username", } expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)} + timeRemaining := expiresOn.T.Sub(now) / 2 + refreshIn := internalTime.DurationTime{T: now.Add(timeRemaining)} tokenResponse := accesstokens.TokenResponse{ AccessToken: "accessToken", RefreshToken: "refreshToken", @@ -1020,6 +1022,7 @@ func TestWrite(t *testing.T) { FamilyID: "fid", ClientInfo: clientInfo, GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}}, + RefreshIn: refreshIn, ExpiresOn: expiresOn, ExtExpiresOn: internalTime.DurationTime{T: now}, TokenType: "Bearer", @@ -1044,7 +1047,7 @@ func TestWrite(t *testing.T) { "realm", "cid", now, - now, + now.Add(500*time.Second), now.Add(1000*time.Second), now, "openid profile", From 566cdd20111d4b9a14179e1b3621f1d757f0c2dd Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 13 Jan 2025 11:23:11 +0000 Subject: [PATCH 03/29] updating storage --- apps/internal/base/base.go | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 9d208e89..981d121c 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -86,7 +86,7 @@ type AuthResult struct { Account shared.Account IDToken accesstokens.IDToken AccessToken string - RefreshIn time.Time + RefreshOn time.Time ExpiresOn time.Time GrantedScopes []string DeclinedScopes []string @@ -125,11 +125,21 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu return AuthResult{}, fmt.Errorf("problem decoding JWT token: %w", err) } } + + var refreshIn time.Time + if !storageTokenResponse.AccessToken.ExpiresOn.T.IsZero() { + now := time.Now() + timeRemaining := storageTokenResponse.AccessToken.ExpiresOn.T.Sub(now) + if timeRemaining > 2*time.Hour { + refreshIn = now.Add(timeRemaining / 2) + } + } + return AuthResult{ Account: account, IDToken: idToken, AccessToken: accessToken, - RefreshIn: storageTokenResponse.AccessToken.RefreshIn.T, + RefreshOn: refreshIn, ExpiresOn: storageTokenResponse.AccessToken.ExpiresOn.T, GrantedScopes: grantedScopes, DeclinedScopes: nil, @@ -348,7 +358,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - if ar.RefreshIn.IsZero() || ar.RefreshIn.After(time.Now()) { + if ar.RefreshOn.After(time.Now()) { return ar, err } } @@ -366,7 +376,16 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err != nil { return ar, err } - return b.AuthResultFromToken(ctx, authParams, token, true) + + result, err := b.AuthResultFromToken(ctx, authParams, token, true) + if err != nil { + if ar.Metadata.TokenSource == Cache && ar.RefreshOn.IsZero() { + return ar, nil + } else { + return result, err + } + } + return result, nil } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { From 78eeace3b66ed32b84f036e233037646fd2b8c8b Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 14 Jan 2025 21:43:20 +0000 Subject: [PATCH 04/29] Updated the force refresh in --- apps/confidential/confidential_test.go | 134 ++++++++++-------- apps/internal/base/base.go | 63 ++++---- apps/internal/base/base_test.go | 43 ++++++ apps/internal/base/internal/storage/items.go | 6 +- .../base/internal/storage/items_test.go | 4 +- .../internal/storage/partitioned_storage.go | 2 +- .../internal/base/internal/storage/storage.go | 2 +- .../base/internal/storage/storage_test.go | 30 ++-- apps/internal/exported/exported.go | 2 + apps/internal/oauth/oauth.go | 16 ++- .../ops/accesstokens/accesstokens_test.go | 8 +- .../internal/oauth/ops/accesstokens/tokens.go | 28 +--- 12 files changed, 191 insertions(+), 147 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 7401269d..f509d68a 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -138,7 +138,7 @@ func TestAcquireTokenByCredential(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - RefreshIn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + RefreshOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, @@ -256,7 +256,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token, "", "rt", "", 3600, 7400))) client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { @@ -279,7 +279,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { } // new assertion should trigger new token request token2 := token + "2" - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token2, "", "rt", "", 3600, 360))) tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope) if err != nil { t.Fatal(err) @@ -464,7 +464,6 @@ func TestADFSTokenCaching(t *testing.T) { AccessToken: "at1", RefreshToken: "rt", TokenType: "bearer", - RefreshIn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, @@ -710,7 +709,8 @@ func TestNewCredFromCertError(t *testing.T) { func TestNewCredFromTokenProvider(t *testing.T) { expectedToken := "expected token" called := false - expiresIn := 4200 + expiresIn := 18000 + refreshOn := expiresIn / 2 key := struct{}{} ctx := context.WithValue(context.Background(), key, true) cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { @@ -730,6 +730,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { return exported.TokenProviderResult{ AccessToken: expectedToken, ExpiresInSeconds: expiresIn, + RefreshInSeconds: refreshOn, }, nil }) client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) @@ -746,6 +747,9 @@ func TestNewCredFromTokenProvider(t *testing.T) { if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn { t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v) } + if v := int(time.Until(ar.RefreshOn).Seconds()); v < refreshOn-2 || v > refreshOn { + t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, v) + } if ar.AccessToken != expectedToken { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } @@ -758,6 +762,70 @@ func TestNewCredFromTokenProvider(t *testing.T) { } } +func TestRefreshIn(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + firstToken := "first token" + secondToken := "new token" + refreshIn := 8200 + lmo := "login.microsoftonline.com" + tenant := "tenant" + + for _, needsRefresh := range []bool{false, true} { + name := "token doesn't need refresh" + if needsRefresh { + name = "token needs refresh" + refreshIn = 83 + } + t.Run(name, func(t *testing.T) { + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, refreshIn))), + ) + // mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(firstToken, "", "", "", 14400))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, refreshIn))), + ) + // mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(secondToken, "", "", "", 14400))) + + // Create the client instance + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + + // Acquire the first token + ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope) + if err != nil { + t.Fatal(err) + } + + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } + if ar.RefreshOn.IsZero() { + t.Fatal("RefreshOn shouldn't be zero") + } + if v := int(time.Until(ar.RefreshOn).Seconds()); v < refreshIn-102 || v > refreshIn { + t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, v) + } + + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope) + if err != nil { + t.Fatal(err) + } + if (ar.AccessToken == secondToken) != needsRefresh { + t.Fatalf("wanted %q, got %q", secondToken, ar.AccessToken) + } + }) + } +} + func TestNewCredFromTokenProviderError(t *testing.T) { expectedError := "something went wrong" cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { @@ -814,62 +882,6 @@ func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.Rep return nil } -func TestAcquireTokenSilentRefreshIn(t *testing.T) { - - for _, test := range []struct { - expireOn int - refreshIn int - }{ - {3600, 1}, - {7200, 3600}, - } { - cache := make(testCache) - accessToken := "*" - lmo := "login.microsoftonline.com" - tenantA, tenantB := "a", "b" - authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) - mockClient := mock.Client{} - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", test.expireOn, test.refreshIn))) - - cred, err := NewCredFromSecret(fakeSecret) - if err != nil { - t.Fatal(err) - } - client, err := New(authorityA, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) - if err != nil { - t.Fatal(err) - } - // The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test - ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) - if err != nil { - t.Fatal(err) - } - if ar.AccessToken != accessToken { - t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) - } - account := ar.Account - if actual := account.Realm; actual != tenantA { - t.Fatalf(`unexpected realm "%s"`, actual) - } - - // a client configured for a different tenant should be able to authenticate silently with the shared cache's data - client, err = New(authorityB, fakeClientID, cred, WithCache(&cache), WithHTTPClient(&mockClient)) - if err != nil { - t.Fatal(err) - } - // this should succeed because the cache contains an access token from tenantA - mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA))) - ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA)) - if err != nil && test.refreshIn > 1 { - t.Fatal(err) - } - if ar.AccessToken != accessToken { - t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) - } - } -} - func TestWithCache(t *testing.T) { cache := make(testCache) accessToken := "*" diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 981d121c..cab8f196 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -5,8 +5,8 @@ package base import ( "context" - "errors" "fmt" + "net/http" "net/url" "reflect" "strings" @@ -14,6 +14,7 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" @@ -83,9 +84,10 @@ type AcquireTokenOnBehalfOfParameters struct { // AuthResult contains the results of one token acquisition operation in PublicClientApplication // or ConfidentialClientApplication. For details see https://aka.ms/msal-net-authenticationresult type AuthResult struct { - Account shared.Account - IDToken accesstokens.IDToken - AccessToken string + Account shared.Account + IDToken accesstokens.IDToken + AccessToken string + //RefreshOn indicates the recommended time to request a new access token, or zero if no refresh time is suggested RefreshOn time.Time ExpiresOn time.Time GrantedScopes []string @@ -125,21 +127,11 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu return AuthResult{}, fmt.Errorf("problem decoding JWT token: %w", err) } } - - var refreshIn time.Time - if !storageTokenResponse.AccessToken.ExpiresOn.T.IsZero() { - now := time.Now() - timeRemaining := storageTokenResponse.AccessToken.ExpiresOn.T.Sub(now) - if timeRemaining > 2*time.Hour { - refreshIn = now.Add(timeRemaining / 2) - } - } - return AuthResult{ Account: account, IDToken: idToken, AccessToken: accessToken, - RefreshOn: refreshIn, + RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T, ExpiresOn: storageTokenResponse.AccessToken.ExpiresOn.T, GrantedScopes: grantedScopes, DeclinedScopes: nil, @@ -158,6 +150,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco Account: account, IDToken: tokenResponse.IDToken, AccessToken: tokenResponse.AccessToken, + RefreshOn: tokenResponse.RefreshOn.T, ExpiresOn: tokenResponse.ExpiresOn.T, GrantedScopes: tokenResponse.GrantedScopes.Slice, Metadata: AuthResultMetadata{ @@ -357,10 +350,28 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) - if ar.RefreshOn.After(time.Now()) { - return ar, err + if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { + return b.AuthResultFromToken(ctx, authParams, tr, true) + } else if callErr, ok := er.(*errors.CallErr); ok { + // Check if the error is of type CallErr and matches the relevant status codes + switch callErr.Resp.StatusCode { + case http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout: // 504 + default: + // Handle non-retryable errors + return AuthResult{}, er + } + } } + ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) + return ar, err } } @@ -376,16 +387,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err != nil { return ar, err } - - result, err := b.AuthResultFromToken(ctx, authParams, token, true) - if err != nil { - if ar.Metadata.TokenSource == Cache && ar.RefreshOn.IsZero() { - return ar, nil - } else { - return result, err - } - } - return result, nil + return b.AuthResultFromToken(ctx, authParams, token, true) } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -481,6 +483,11 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au return ar, err } +// shouldRefresh returns true if the token should be refreshed. +func shouldRefresh(t time.Time) bool { + return !t.IsZero() && t.Before(time.Now().Add(2*time.Hour)) +} + func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { if b.cacheAccessor != nil { b.cacheAccessorMu.RLock() diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 09238780..7414ecbd 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -444,3 +444,46 @@ func TestAuthResultFromStorage(t *testing.T) { } } } + +func TestShouldRefresh(t *testing.T) { + // Get current time + now := time.Now() + + // Test cases + tests := []struct { + name string + input time.Time + expected bool + }{ + { + name: "Zero time", + input: time.Time{}, + expected: false, + }, + { + name: "More than 2 hours ago", + input: now.Add(3 * time.Hour).Add(time.Second), + expected: false, + }, + { + name: "Exactly 2 hours ago", + input: now.Add(2 * time.Hour).Add(time.Second), + expected: false, + }, + { + name: "Less than 2 hours ago", + input: now.Add(1 * time.Hour).Add(time.Second), + expected: true, + }, + } + + // Run the test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := shouldRefresh(tt.input) + if actual != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, actual) + } + }) + } +} diff --git a/apps/internal/base/internal/storage/items.go b/apps/internal/base/internal/storage/items.go index 33b62c6b..6a0499ce 100644 --- a/apps/internal/base/internal/storage/items.go +++ b/apps/internal/base/internal/storage/items.go @@ -72,7 +72,7 @@ type AccessToken struct { ClientID string `json:"client_id,omitempty"` Secret string `json:"secret,omitempty"` Scopes string `json:"target,omitempty"` - RefreshIn internalTime.Unix `json:"refresh_in,omitempty"` + RefreshOn internalTime.Unix `json:"refresh_on,omitempty"` ExpiresOn internalTime.Unix `json:"expires_on,omitempty"` ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"` CachedAt internalTime.Unix `json:"cached_at,omitempty"` @@ -84,7 +84,7 @@ type AccessToken struct { } // NewAccessToken is the constructor for AccessToken. -func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshIn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken { +func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshOn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken { return AccessToken{ HomeAccountID: homeID, Environment: env, @@ -94,7 +94,7 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshIn, ex Secret: token, Scopes: scopes, CachedAt: internalTime.Unix{T: cachedAt.UTC()}, - RefreshIn: internalTime.Unix{T: refreshIn.UTC()}, + RefreshOn: internalTime.Unix{T: refreshOn.UTC()}, ExpiresOn: internalTime.Unix{T: expiresOn.UTC()}, ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()}, TokenType: tokenType, diff --git a/apps/internal/base/internal/storage/items_test.go b/apps/internal/base/internal/storage/items_test.go index a5373452..3febaa7d 100644 --- a/apps/internal/base/internal/storage/items_test.go +++ b/apps/internal/base/internal/storage/items_test.go @@ -58,7 +58,7 @@ var ( func TestCreateAccessToken(t *testing.T) { testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) - testRefreshIn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) + testRefreshOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC) testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC) actualAt := NewAccessToken("testHID", @@ -66,7 +66,7 @@ func TestCreateAccessToken(t *testing.T) { "realm", "clientID", testCachedAt, - testRefreshIn, + testRefreshOn, testExpiresOn, testExtExpiresOn, "user.read", diff --git a/apps/internal/base/internal/storage/partitioned_storage.go b/apps/internal/base/internal/storage/partitioned_storage.go index 1ffc928e..5b468a3f 100644 --- a/apps/internal/base/internal/storage/partitioned_storage.go +++ b/apps/internal/base/internal/storage/partitioned_storage.go @@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes realm, clientID, cachedAt, - tokenResponse.RefreshIn.T, + tokenResponse.RefreshOn.T, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, diff --git a/apps/internal/base/internal/storage/storage.go b/apps/internal/base/internal/storage/storage.go index bc34f433..80a0fe13 100644 --- a/apps/internal/base/internal/storage/storage.go +++ b/apps/internal/base/internal/storage/storage.go @@ -193,7 +193,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces realm, clientID, cachedAt, - tokenResponse.RefreshIn.T, + tokenResponse.RefreshOn.T, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, diff --git a/apps/internal/base/internal/storage/storage_test.go b/apps/internal/base/internal/storage/storage_test.go index f8e13334..96591df1 100644 --- a/apps/internal/base/internal/storage/storage_test.go +++ b/apps/internal/base/internal/storage/storage_test.go @@ -851,7 +851,7 @@ func TestIsAccessTokenValid(t *testing.T) { cachedAt := time.Now() badCachedAt := time.Now().Add(500 * time.Second) expiresOn := time.Now().Add(1000 * time.Second) - refreshIn := time.Now().Add(1000 * time.Second) + refreshOn := time.Now().Add(1000 * time.Second) badExpiresOn := time.Now().Add(200 * time.Second) extended := time.Now() @@ -862,16 +862,16 @@ func TestIsAccessTokenValid(t *testing.T) { }{ { desc: "Success", - token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshOn, expiresOn, extended, "openid", "secret", "tokenType", ""), }, { desc: "ExpiresOnUnixTimestamp has expired", - token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshIn, badExpiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, refreshOn, badExpiresOn, extended, "openid", "secret", "tokenType", ""), err: true, }, { desc: "Success", - token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, refreshIn, expiresOn, extended, "openid", "secret", "tokenType", ""), + token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, refreshOn, expiresOn, extended, "openid", "secret", "tokenType", ""), err: true, }, } @@ -888,15 +888,16 @@ func TestIsAccessTokenValid(t *testing.T) { } func TestRead(t *testing.T) { + now := time.Now() accessTokenCacheItem := NewAccessToken( "hid", "env", "realm", "cid", - time.Now(), - time.Now(), - time.Now().Add(1000*time.Second), - time.Now(), + now, + now.Add(500*time.Second), + now.Add(1000*time.Second), + now, "openid profile", "secret", "Bearer", @@ -1013,8 +1014,7 @@ func TestWrite(t *testing.T) { PreferredUsername: "username", } expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)} - timeRemaining := expiresOn.T.Sub(now) / 2 - refreshIn := internalTime.DurationTime{T: now.Add(timeRemaining)} + refreshOn := internalTime.DurationTime{T: now.Add(500 * time.Second)} tokenResponse := accesstokens.TokenResponse{ AccessToken: "accessToken", RefreshToken: "refreshToken", @@ -1022,7 +1022,7 @@ func TestWrite(t *testing.T) { FamilyID: "fid", ClientInfo: clientInfo, GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}}, - RefreshIn: refreshIn, + RefreshOn: refreshOn, ExpiresOn: expiresOn, ExtExpiresOn: internalTime.DurationTime{T: now}, TokenType: "Bearer", @@ -1047,7 +1047,7 @@ func TestWrite(t *testing.T) { "realm", "cid", now, - now.Add(500*time.Second), + refreshOn.T, now.Add(1000*time.Second), now, "openid profile", @@ -1145,7 +1145,7 @@ func TestRemoveRefreshTokens(t *testing.T) { func TestRemoveAccessTokens(t *testing.T) { now := time.Now() storageManager := newForTest(nil) - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, time.Time{}, now, now, "openid", "secret", "tokenType", "") key := testAccessToken.Key() contract := &Contract{ AccessTokens: map[string]AccessToken{ @@ -1196,7 +1196,7 @@ func TestRemoveAccountObject(t *testing.T) { func TestRemoveAccount(t *testing.T) { now := time.Now() - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, time.Time{}, now, now, "openid profile", "secret", "tokenType", "") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") @@ -1238,7 +1238,7 @@ func TestRemoveAccount(t *testing.T) { func TestRemoveEmptyAccount(t *testing.T) { now := time.Now() - testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, now, now, now, "openid profile", "secret", "tokenType", "") + testAccessToken := NewAccessToken("hid", "env", "realm", "cid", now, time.Time{}, now, now, "openid profile", "secret", "tokenType", "") testIDToken := NewIDToken("hid", "env", "realm", "cid", "secret") testAppMeta := NewAppMetaData("fid", "cid", "env") testRefreshToken := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid") diff --git a/apps/internal/exported/exported.go b/apps/internal/exported/exported.go index 7b673e3f..de1bf381 100644 --- a/apps/internal/exported/exported.go +++ b/apps/internal/exported/exported.go @@ -31,4 +31,6 @@ type TokenProviderResult struct { AccessToken string // ExpiresInSeconds is the lifetime of the token in seconds ExpiresInSeconds int + // RefreshInSeconds indicates the suggested time to refresh the token, if any + RefreshInSeconds int } diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index e0653134..3ec8df9c 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -111,7 +111,7 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams Scopes: scopes, TenantID: authParams.AuthorityInfo.Tenant, } - tr, err := cred.TokenProvider(ctx, params) + pr, err := cred.TokenProvider(ctx, params) if err != nil { if len(scopes) == 0 { err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err) @@ -119,14 +119,20 @@ func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams } return accesstokens.TokenResponse{}, err } - return accesstokens.TokenResponse{ + tr := accesstokens.TokenResponse{ TokenType: authParams.AuthnScheme.AccessTokenType(), - AccessToken: tr.AccessToken, + AccessToken: pr.AccessToken, ExpiresOn: internalTime.DurationTime{ - T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second), + T: now.Add(time.Duration(pr.ExpiresInSeconds) * time.Second), }, GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes}, - }, nil + } + if pr.RefreshInSeconds > 0 { + tr.RefreshOn = internalTime.DurationTime{ + T: now.Add(time.Duration(pr.RefreshInSeconds) * time.Second), + } + } + return tr, nil } if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil { diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index 7d9018a5..d3ffad5c 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -844,14 +844,14 @@ func TestTokenResponseUnmarshal(t *testing.T) { if got.ExpiresOn.T.Before(time.Now().Add(time.Hour * 2)) { expectedRefreshIn := now.Add(timeRemaining) const tolerance = 100 * time.Millisecond - if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance { - t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn) + if got.RefreshOn.T.Sub(expectedRefreshIn) > tolerance { + t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshOn.T, expectedRefreshIn) } } else { expectedRefreshIn := now.Add(timeRemaining / 2) const tolerance = 100 * time.Millisecond - if got.RefreshIn.T.Sub(expectedRefreshIn) > tolerance { - t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshIn.T, expectedRefreshIn) + if got.RefreshOn.T.Sub(expectedRefreshIn) > tolerance { + t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshOn.T, expectedRefreshIn) } } diff --git a/apps/internal/oauth/ops/accesstokens/tokens.go b/apps/internal/oauth/ops/accesstokens/tokens.go index 76c895f7..6ebb1bba 100644 --- a/apps/internal/oauth/ops/accesstokens/tokens.go +++ b/apps/internal/oauth/ops/accesstokens/tokens.go @@ -173,7 +173,7 @@ type TokenResponse struct { FamilyID string `json:"foci"` IDToken IDToken `json:"id_token"` ClientInfo ClientInfo `json:"client_info"` - RefreshIn internalTime.DurationTime `json:"refresh_in"` + RefreshOn internalTime.DurationTime `json:"refresh_in,omitempty"` ExpiresOn internalTime.DurationTime `json:"expires_in"` ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"` GrantedScopes Scopes `json:"scope"` @@ -184,32 +184,6 @@ type TokenResponse struct { scopesComputed bool } -func (t *TokenResponse) UnmarshalJSON(data []byte) error { - type Alias TokenResponse - aux := &struct { - *Alias - }{ - Alias: (*Alias)(t), - } - - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - - // If refresh_in is not set, compute it as (now - expires_on) / 2 if expires_on > 2 hours from now - if t.RefreshIn.T.IsZero() && !t.ExpiresOn.T.IsZero() { - now := time.Now() - timeRemaining := t.ExpiresOn.T.Sub(now) - if timeRemaining > 2*time.Hour { - t.RefreshIn = internalTime.DurationTime{T: now.Add(timeRemaining / 2)} - } else { - t.RefreshIn = internalTime.DurationTime{T: t.ExpiresOn.T} - } - } - - return nil -} - // ComputeScope computes the final scopes based on what was granted by the server and // what our AuthParams were from the authority server. Per OAuth spec, if no scopes are returned, the response should be treated as if all scopes were granted // This behavior can be observed in client assertion flows, but can happen at any time, this check ensures we treat From de241bc2446cb9319901ff3a4f4b384dad325320 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 15 Jan 2025 14:14:53 +0000 Subject: [PATCH 05/29] Updated test description --- apps/confidential/confidential_test.go | 2 -- apps/internal/oauth/ops/accesstokens/accesstokens_test.go | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index f509d68a..1c2f4774 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -786,11 +786,9 @@ func TestRefreshIn(t *testing.T) { mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, refreshIn))), ) - // mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(firstToken, "", "", "", 14400))) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, refreshIn))), ) - // mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(secondToken, "", "", "", 14400))) // Create the client instance client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index d3ffad5c..dc0eab43 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -757,7 +757,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { jwtDecoder: jwtDecoderFake, }, { - desc: "Success", + desc: "Success with same expires_in and ext_expires_in", payload: ` { "access_token": "secret", @@ -779,7 +779,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { jwtDecoder: jwtDecoderFake, }, { - desc: "Success", + desc: "Success with different expires_in and ext_expires_in", payload: ` { "access_token": "secret", @@ -801,7 +801,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { jwtDecoder: jwtDecoderFake, }, { - desc: "Success", + desc: "Success with refresh_in not provided in response", payload: ` { "access_token": "secret", From 4a38a84a282a24514ed9ba7c243c47d15f4edaeb Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 15 Jan 2025 14:36:03 +0000 Subject: [PATCH 06/29] Updated test --- .../ops/accesstokens/accesstokens_test.go | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index dc0eab43..d41593a8 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -762,6 +762,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { { "access_token": "secret", "expires_in": 86399, + "refresh_in": 43199, "ext_expires_in": 86399, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" @@ -769,6 +770,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { want: TokenResponse{ AccessToken: "secret", ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + RefreshOn: internalTime.DurationTime{T: time.Unix(43199, 0)}, ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ @@ -779,11 +781,12 @@ func TestTokenResponseUnmarshal(t *testing.T) { jwtDecoder: jwtDecoderFake, }, { - desc: "Success with different expires_in and ext_expires_in", + desc: "Success with different expires_in and refresh On", payload: ` { "access_token": "secret", "expires_in": 3600, + "refresh_in": 43199, "ext_expires_in": 86399, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" @@ -791,6 +794,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { want: TokenResponse{ AccessToken: "secret", ExpiresOn: internalTime.DurationTime{T: time.Unix(3600, 0)}, + RefreshOn: internalTime.DurationTime{T: time.Unix(43199, 0)}, ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ @@ -839,21 +843,6 @@ func TestTokenResponseUnmarshal(t *testing.T) { case err != nil: continue } - now := time.Now() - timeRemaining := got.ExpiresOn.T.Sub(now) - if got.ExpiresOn.T.Before(time.Now().Add(time.Hour * 2)) { - expectedRefreshIn := now.Add(timeRemaining) - const tolerance = 100 * time.Millisecond - if got.RefreshOn.T.Sub(expectedRefreshIn) > tolerance { - t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshOn.T, expectedRefreshIn) - } - } else { - expectedRefreshIn := now.Add(timeRemaining / 2) - const tolerance = 100 * time.Millisecond - if got.RefreshOn.T.Sub(expectedRefreshIn) > tolerance { - t.Errorf("Expected refresh_in to be half of expires_on, but got %v, expected %v", got.RefreshOn.T, expectedRefreshIn) - } - } // Note: IncludeUnexported prevents minor differences in time.Time due to internal fields. if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" { From 0102a9eb4e1433edc3a5435fa94ad606b0e902c3 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 16 Jan 2025 14:13:02 +0000 Subject: [PATCH 07/29] Updated force refresh in Shifted RefreshOn in Metadata --- apps/confidential/confidential_test.go | 58 ++++++++++++++++++-------- apps/internal/base/base.go | 14 ++++--- apps/internal/base/base_test.go | 32 ++++++-------- apps/internal/mock/mock.go | 6 +++ 4 files changed, 67 insertions(+), 43 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 1c2f4774..6afd69f9 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -25,6 +25,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" @@ -747,7 +748,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn { t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v) } - if v := int(time.Until(ar.RefreshOn).Seconds()); v < refreshOn-2 || v > refreshOn { + if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshOn-2 || v > refreshOn { t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, v) } if ar.AccessToken != expectedToken { @@ -769,55 +770,76 @@ func TestRefreshIn(t *testing.T) { } firstToken := "first token" secondToken := "new token" - refreshIn := 8200 lmo := "login.microsoftonline.com" tenant := "tenant" - - for _, needsRefresh := range []bool{false, true} { + refreshIn := 43200 + expiresIn := 86400 + for _, tt := range []struct { + shouldGetNewToken bool + secondRequestAfter int + shouldReturnError bool + }{ + {secondRequestAfter: 40000, shouldGetNewToken: false}, // from cache + {secondRequestAfter: 43400, shouldGetNewToken: true}, // refresh in expired so new token + {secondRequestAfter: 40000, shouldGetNewToken: false, shouldReturnError: true}, // refresh in not expired but refresh failed so new token + {secondRequestAfter: 80000, shouldGetNewToken: true, shouldReturnError: false}, // refresh in expired but refresh failed so new token + {secondRequestAfter: 1003400, shouldGetNewToken: true}, + } { name := "token doesn't need refresh" - if needsRefresh { - name = "token needs refresh" - refreshIn = 83 - } t.Run(name, func(t *testing.T) { + originalTime := base.GetCurrentTime + defer func() { + base.GetCurrentTime = originalTime + }() // Create a mock client and append mock responses mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, refreshIn))), - ) - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":14400,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, refreshIn))), + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), ) + if tt.shouldReturnError { + mockClient.AppendResponse( + mock.WithCode(http.StatusBadGateway), + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + } else { + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), + ) + } // Create the client instance client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { t.Fatal(err) } - // Acquire the first token ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope) if err != nil { t.Fatal(err) } - // Assert the first token is returned if ar.AccessToken != firstToken { t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) } - if ar.RefreshOn.IsZero() { + if ar.Metadata.RefreshOn.IsZero() { t.Fatal("RefreshOn shouldn't be zero") } - if v := int(time.Until(ar.RefreshOn).Seconds()); v < refreshIn-102 || v > refreshIn { + if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshIn-10 || v > refreshIn+10 { t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, v) } - + fixedTime := time.Now().Add(time.Duration(tt.secondRequestAfter) * time.Second) + base.GetCurrentTime = func() time.Time { + return fixedTime + } ar, err = client.AcquireTokenSilent(context.Background(), tokenScope) if err != nil { t.Fatal(err) } - if (ar.AccessToken == secondToken) != needsRefresh { + if ar.Metadata.TokenSource != base.Cache && !tt.shouldGetNewToken { + t.Fatal("should have returned from cache.") + } + if (ar.AccessToken == secondToken) != tt.shouldGetNewToken { t.Fatalf("wanted %q, got %q", secondToken, ar.AccessToken) } }) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index cab8f196..9c49b162 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -88,7 +88,6 @@ type AuthResult struct { IDToken accesstokens.IDToken AccessToken string //RefreshOn indicates the recommended time to request a new access token, or zero if no refresh time is suggested - RefreshOn time.Time ExpiresOn time.Time GrantedScopes []string DeclinedScopes []string @@ -97,6 +96,7 @@ type AuthResult struct { // AuthResultMetadata which contains meta data for the AuthResult type AuthResultMetadata struct { + RefreshOn time.Time TokenSource TokenSource } @@ -131,12 +131,12 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu Account: account, IDToken: idToken, AccessToken: accessToken, - RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T, ExpiresOn: storageTokenResponse.AccessToken.ExpiresOn.T, GrantedScopes: grantedScopes, DeclinedScopes: nil, Metadata: AuthResultMetadata{ TokenSource: Cache, + RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T, }, }, nil } @@ -150,11 +150,11 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco Account: account, IDToken: tokenResponse.IDToken, AccessToken: tokenResponse.AccessToken, - RefreshOn: tokenResponse.RefreshOn.T, ExpiresOn: tokenResponse.ExpiresOn.T, GrantedScopes: tokenResponse.GrantedScopes.Slice, Metadata: AuthResultMetadata{ TokenSource: IdentityProvider, + RefreshOn: tokenResponse.RefreshOn.T, }, }, nil } @@ -351,8 +351,6 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { return b.AuthResultFromToken(ctx, authParams, tr, true) } else if callErr, ok := er.(*errors.CallErr); ok { @@ -483,9 +481,13 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au return ar, err } +// This function wraps time.Now() and is used for refreshing the application +// was created to test the function against refreshin +var GetCurrentTime = time.Now + // shouldRefresh returns true if the token should be refreshed. func shouldRefresh(t time.Time) bool { - return !t.IsZero() && t.Before(time.Now().Add(2*time.Hour)) + return !t.IsZero() && t.Before(GetCurrentTime()) } func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 7414ecbd..31acf52b 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -445,11 +445,11 @@ func TestAuthResultFromStorage(t *testing.T) { } } +// TestShouldRefresh tests the shouldRefresh function func TestShouldRefresh(t *testing.T) { - // Get current time + // Get the current time to use for comparison now := time.Now() - // Test cases tests := []struct { name string input time.Time @@ -457,32 +457,26 @@ func TestShouldRefresh(t *testing.T) { }{ { name: "Zero time", - input: time.Time{}, - expected: false, + input: time.Time{}, // Zero time + expected: false, // Should return false because it's zero time }, { - name: "More than 2 hours ago", - input: now.Add(3 * time.Hour).Add(time.Second), - expected: false, + name: "Future time", + input: now.Add(time.Hour), // 1 hour in the future + expected: false, // Should return false because it's in the future }, { - name: "Exactly 2 hours ago", - input: now.Add(2 * time.Hour).Add(time.Second), - expected: false, - }, - { - name: "Less than 2 hours ago", - input: now.Add(1 * time.Hour).Add(time.Second), - expected: true, + name: "Past time", + input: now.Add(-time.Hour), // 1 hour in the past + expected: true, // Should return true because it's in the past }, } - // Run the test cases for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - actual := shouldRefresh(tt.input) - if actual != tt.expected { - t.Errorf("expected %v, got %v", tt.expected, actual) + result := shouldRefresh(tt.input) + if result != tt.expected { + t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) } }) } diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 80c8fe81..401c02df 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,6 +46,12 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } +func WithCode(code int) responseOption { + return respOpt(func(r *response) { + r.code = code + }) +} + // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { resp []response From 10e7c6f435764fadc533400d7ac5fcf7e09a1a8d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 23 Jan 2025 15:51:50 +0000 Subject: [PATCH 08/29] Updated some tests --- apps/confidential/confidential_test.go | 27 ++++++++++----- .../ops/accesstokens/accesstokens_test.go | 34 +++++++++---------- 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 6afd69f9..b0cf3991 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -139,9 +139,9 @@ func TestAcquireTokenByCredential(t *testing.T) { } client, err := fakeClient(accesstokens.TokenResponse{ AccessToken: token, - RefreshOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + RefreshOn: internalTime.DurationTime{T: time.Now().Add(6 * time.Hour)}, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "Bearer", }, cred, fakeAuthority) @@ -745,12 +745,13 @@ func TestNewCredFromTokenProvider(t *testing.T) { if !called { t.Fatal("token provider wasn't invoked") } - if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn { - t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v) + if !isTimeSame(ar.ExpiresOn, expiresIn) { + t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, ar.ExpiresOn.Second()) } - if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshOn-2 || v > refreshOn { - t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, v) + if !isTimeSame(ar.Metadata.RefreshOn, refreshOn) { + t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, ar.Metadata.RefreshOn.Second()) } + if ar.AccessToken != expectedToken { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } @@ -825,8 +826,8 @@ func TestRefreshIn(t *testing.T) { if ar.Metadata.RefreshOn.IsZero() { t.Fatal("RefreshOn shouldn't be zero") } - if v := int(time.Until(ar.Metadata.RefreshOn).Seconds()); v < refreshIn-10 || v > refreshIn+10 { - t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, v) + if !isTimeSame(ar.Metadata.RefreshOn, refreshIn) { + t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, ar.Metadata.RefreshOn.Second()) } fixedTime := time.Now().Add(time.Duration(tt.secondRequestAfter) * time.Second) base.GetCurrentTime = func() time.Time { @@ -846,6 +847,14 @@ func TestRefreshIn(t *testing.T) { } } +func isTimeSame(t time.Time, expectedSeconds int) bool { + v := int(time.Until(t).Seconds()) + if v < expectedSeconds-2 || v > expectedSeconds+2 { + return false + } + return true +} + func TestNewCredFromTokenProviderError(t *testing.T) { expectedError := "something went wrong" cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) { diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index d41593a8..6a8ffe07 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -748,8 +748,8 @@ func TestTokenResponseUnmarshal(t *testing.T) { payload: ` { "access_token": "secret", - "expires_in": 86399, - "ext_expires_in": 86399, + "expires_in": 86400, + "ext_expires_in": 86400, "client_info": error, "scope": "openid profile" }`, @@ -761,17 +761,17 @@ func TestTokenResponseUnmarshal(t *testing.T) { payload: ` { "access_token": "secret", - "expires_in": 86399, - "refresh_in": 43199, - "ext_expires_in": 86399, + "expires_in": 86400, + "refresh_in": 43200, + "ext_expires_in": 86400, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, - RefreshOn: internalTime.DurationTime{T: time.Unix(43199, 0)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + ExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, + RefreshOn: internalTime.DurationTime{T: time.Unix(43200, 0)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", @@ -785,17 +785,17 @@ func TestTokenResponseUnmarshal(t *testing.T) { payload: ` { "access_token": "secret", - "expires_in": 3600, - "refresh_in": 43199, - "ext_expires_in": 86399, + "expires_in": 86400, + "refresh_in": 43200, + "ext_expires_in": 86400, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Unix(3600, 0)}, + ExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, RefreshOn: internalTime.DurationTime{T: time.Unix(43199, 0)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", @@ -809,15 +809,15 @@ func TestTokenResponseUnmarshal(t *testing.T) { payload: ` { "access_token": "secret", - "expires_in": 36000, - "ext_expires_in": 86399, + "expires_in": 86400, + "ext_expires_in": 86400, "client_info": {"uid": "uid","utid": "utid"}, "scope": "openid profile" }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: internalTime.DurationTime{T: time.Unix(36000, 0)}, - ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)}, + ExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ UID: "uid", From e95ece3c143351e0d996e1833d4a43feb22f95b9 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 23 Jan 2025 15:57:56 +0000 Subject: [PATCH 09/29] Updated time. --- apps/confidential/confidential_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index b0cf3991..00c9fa02 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -257,7 +257,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token, "", "rt", "", 3600, 7400))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token, "", "rt", "", 86400, 43200))) client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { @@ -280,7 +280,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { } // new assertion should trigger new token request token2 := token + "2" - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token2, "", "rt", "", 3600, 360))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token2, "", "rt", "", 86400, 43200))) tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope) if err != nil { t.Fatal(err) From 91dd29f22f6ea80c7e484e37988676df21872e7b Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 29 Jan 2025 14:30:49 +0000 Subject: [PATCH 10/29] Cleaned up code reference with PR comments --- apps/confidential/confidential.go | 4 +- apps/confidential/confidential_test.go | 51 +++++++++++++++++----- apps/internal/base/base.go | 13 +++--- apps/internal/base/base_test.go | 3 +- apps/internal/mock/mock.go | 30 +++++-------- apps/public/public.go | 6 +-- apps/public/public_test.go | 26 +++++------ apps/tests/benchmarks/confidential.go | 2 +- apps/tests/performance/performance_test.go | 2 +- 9 files changed, 77 insertions(+), 60 deletions(-) diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 5b375794..1e076c81 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -637,7 +637,7 @@ func (cca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []s if err != nil { return AuthResult{}, err } - return cca.base.AuthResultFromToken(ctx, authParams, token, true) + return cca.base.AuthResultFromToken(ctx, authParams, token) } // acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow. @@ -731,7 +731,7 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string, if err != nil { return AuthResult{}, err } - return cca.base.AuthResultFromToken(ctx, authParams, token, true) + return cca.base.AuthResultFromToken(ctx, authParams, token) } // acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 00c9fa02..31b7eb6d 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -257,7 +257,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { // TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351 mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token, "", "rt", "", 86400, 43200))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 86400, 43200))) client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) if err != nil { @@ -280,7 +280,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) { } // new assertion should trigger new token request token2 := token + "2" - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBodyWithRefreshIn(token2, "", "rt", "", 86400, 43200))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 86400, 43200))) tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope) if err != nil { t.Fatal(err) @@ -434,7 +434,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) { t.Fatal("silent auth should fail because the cache is empty") } mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600, 0))) if _, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)); err != nil { t.Fatal(err) } @@ -800,7 +800,7 @@ func TestRefreshIn(t *testing.T) { ) if tt.shouldReturnError { mockClient.AppendResponse( - mock.WithCode(http.StatusBadGateway), + mock.WithHTTPStatusCode(http.StatusBadGateway), mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), ) } else { @@ -870,6 +870,33 @@ func TestNewCredFromTokenProviderError(t *testing.T) { } } +func TestTokenProviderResultForRefreshIn(t *testing.T) { + accessToken, claims, tenant := "at", "claims", "tenant" + cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) { + if tpp.Claims != claims { + t.Fatalf(`unexpected claims "%s"`, tpp.Claims) + } + if tpp.TenantID != tenant { + t.Fatalf(`unexpected tenant "%s"`, tpp.TenantID) + } + return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 36000, RefreshInSeconds: 18000}, nil + }) + client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{})) + if err != nil { + t.Fatal(err) + } + ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithClaims(claims), WithTenantID(tenant)) + if err != nil { + t.Fatal(err) + } + if !isTimeSame(ar.Metadata.RefreshOn, 18000) { + t.Fatal("RefreshOn should be 18000 seconds from now") + } + if ar.AccessToken != accessToken { + t.Fatalf(`unexpected access token "%s"`, ar.AccessToken) + } +} + func TestTokenProviderOptions(t *testing.T) { accessToken, claims, tenant := "at", "claims", "tenant" cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) { @@ -919,7 +946,7 @@ func TestWithCache(t *testing.T) { authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600, 0))) cred, err := NewCredFromSecret(fakeSecret) if err != nil { @@ -1030,7 +1057,7 @@ func TestWithClaims(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) @@ -1094,7 +1121,7 @@ func TestWithClaims(t *testing.T) { // client has cached access and refresh tokens. When given claims, it should redeem a refresh token for a new access token. newToken := "new-access-token" mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) @@ -1154,7 +1181,7 @@ func TestWithTenantID(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant))) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient)) @@ -1211,7 +1238,7 @@ func TestWithTenantID(t *testing.T) { otherTenant := "not-" + test.tenant if method == "obo" { mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0))) if _, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(otherTenant)); err != nil { t.Fatal(err) } @@ -1249,7 +1276,7 @@ func TestWithTenantID(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant)) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600, 0)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) if i == 0 { @@ -1305,7 +1332,7 @@ func TestWithInstanceDiscovery(t *testing.T) { } mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(stackurl, tenant))) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)), ) client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { @@ -1367,7 +1394,7 @@ func TestWithPortAuthority(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant))) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient)) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 9c49b162..a3725cef 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -352,7 +352,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err == nil { if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { - return b.AuthResultFromToken(ctx, authParams, tr, true) + return b.AuthResultFromToken(ctx, authParams, tr) } else if callErr, ok := er.(*errors.CallErr); ok { // Check if the error is of type CallErr and matches the relevant status codes switch callErr.Resp.StatusCode { @@ -385,7 +385,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err != nil { return ar, err } - return b.AuthResultFromToken(ctx, authParams, token, true) + return b.AuthResultFromToken(ctx, authParams, token) } func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) { @@ -414,7 +414,7 @@ func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams Acqui return AuthResult{}, err } - return b.AuthResultFromToken(ctx, authParams, token, true) + return b.AuthResultFromToken(ctx, authParams, token) } // AcquireTokenOnBehalfOf acquires a security token for an app using middle tier apps access token. @@ -443,15 +443,12 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq authParams.UserAssertion = onBehalfOfParams.UserAssertion token, err := b.Token.OnBehalfOf(ctx, authParams, onBehalfOfParams.Credential) if err == nil { - ar, err = b.AuthResultFromToken(ctx, authParams, token, true) + ar, err = b.AuthResultFromToken(ctx, authParams, token) } return ar, err } -func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) { - if !cacheWrite { - return NewAuthResult(token, shared.Account{}) - } +func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse) (AuthResult, error) { var m manager = b.manager if authParams.AuthorizationType == authority.ATOnBehalfOf { m = b.pmanager diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 31acf52b..69becd23 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -251,7 +251,7 @@ func TestCacheIOErrors(t *testing.T) { if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } - _, actual = client.AuthResultFromToken(ctx, authority.AuthParams{AuthnScheme: &authority.BearerAuthenticationScheme{}}, accesstokens.TokenResponse{}, true) + _, actual = client.AuthResultFromToken(ctx, authority.AuthParams{AuthnScheme: &authority.BearerAuthenticationScheme{}}, accesstokens.TokenResponse{}) if !errors.Is(actual, expected) { t.Fatalf(`expected "%v", got "%v"`, expected, actual) } @@ -284,7 +284,6 @@ func TestCacheIOErrors(t *testing.T) { IDToken: fakeIDToken, RefreshToken: "rt", }, - true, ) if err != nil { t.Fatal(err) diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 401c02df..b357e438 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -46,7 +46,7 @@ func WithCallback(callback func(*http.Request)) responseOption { }) } -func WithCode(code int) responseOption { +func WithHTTPStatusCode(code int) responseOption { return respOpt(func(r *response) { r.code = code }) @@ -82,29 +82,19 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { // CloseIdleConnections implements the comm.HTTPClient interface func (*Client) CloseIdleConnections() {} -func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, expiresIn int) []byte { +func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, expiresIn, refreshIn int) []byte { + // Start building the body with the common fields body := fmt.Sprintf( `{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`, accessToken, expiresIn, time.Now().Add(time.Duration(expiresIn)*time.Second).Unix(), ) - if clientInfo != "" { - body += fmt.Sprintf(`, "client_info": "%s"`, clientInfo) - } - if idToken != "" { - body += fmt.Sprintf(`, "id_token": "%s"`, idToken) - } - if refreshToken != "" { - body += fmt.Sprintf(`, "refresh_token": "%s"`, refreshToken) + + // Conditionally add the "refresh_in" field if refreshIn is provided + if refreshIn > 0 { + body += fmt.Sprintf(`, "refresh_in": %d`, refreshIn) } - body += "}" - return []byte(body) -} -func GetAccessTokenBodyWithRefreshIn(accessToken, idToken, refreshToken, clientInfo string, expiresIn int, refreshIn int) []byte { - body := fmt.Sprintf( - `{"access_token": "%s","expires_in": %d,"refresh_in":%d ,"expires_on": %d,"token_type": "Bearer"`, - accessToken, expiresIn, refreshIn, time.Now().Add(time.Duration(expiresIn)*time.Second).Unix(), - ) + // Add the optional fields if they are provided if clientInfo != "" { body += fmt.Sprintf(`, "client_info": "%s"`, clientInfo) } @@ -114,9 +104,13 @@ func GetAccessTokenBodyWithRefreshIn(accessToken, idToken, refreshToken, clientI if refreshToken != "" { body += fmt.Sprintf(`, "refresh_token": "%s"`, refreshToken) } + + // Close the JSON string body += "}" + return []byte(body) } + func GetIDToken(tenant, issuer string) string { now := time.Now().Unix() payload := []byte(fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, now+3600, now, issuer, tenant)) diff --git a/apps/public/public.go b/apps/public/public.go index 392e5e43..c5110d23 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -387,7 +387,7 @@ func (pca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []s if err != nil { return AuthResult{}, err } - return pca.base.AuthResultFromToken(ctx, authParams, token, true) + return pca.base.AuthResultFromToken(ctx, authParams, token) } type DeviceCodeResult = accesstokens.DeviceCodeResult @@ -412,7 +412,7 @@ func (d DeviceCode) AuthenticationResult(ctx context.Context) (AuthResult, error if err != nil { return AuthResult{}, err } - return d.client.base.AuthResultFromToken(ctx, d.authParams, token, true) + return d.client.base.AuthResultFromToken(ctx, d.authParams, token) } // acquireTokenByDeviceCodeOptions contains optional configuration for AcquireTokenByDeviceCode @@ -687,7 +687,7 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, return AuthResult{}, err } - return pca.base.AuthResultFromToken(ctx, authParams, token, true) + return pca.base.AuthResultFromToken(ctx, authParams, token) } type interactiveAuthResult struct { diff --git a/apps/public/public_test.go b/apps/public/public_test.go index c0fb9b33..b3a1fe62 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -90,7 +90,7 @@ func TestAcquireTokenSilentHomeTenantAliases(t *testing.T) { for _, alias := range []string{"common", "organizations"} { mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, alias))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 3600, 0))) mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant))) client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, alias)), WithHTTPClient(&mockClient)) if err != nil { @@ -130,7 +130,7 @@ func TestAcquireTokenSilentWithTenantID(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) mockClient.AppendResponse(mock.WithBody( - mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf(authorityFmt, lmo, tenant)), "rt-"+tenant, clientInfo, 3600)), + mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf(authorityFmt, lmo, tenant)), "rt-"+tenant, clientInfo, 3600, 0)), ) ar, err := client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(tenant)) if err != nil { @@ -232,7 +232,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) } mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(test.tenant, test.authority), "rt", clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(test.tenant, test.authority), "rt", clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New("client-id", WithAuthority(test.authority), WithHTTPClient(&mockClient)) @@ -290,7 +290,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { } otherTenant := "not-" + test.tenant mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, otherTenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(otherTenant, test.authority), "rt", clientInfo, 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(otherTenant, test.authority), "rt", clientInfo, 3600, 0))) if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithTenantID("not-"+test.tenant)); err != nil { t.Fatal(err) } @@ -386,7 +386,7 @@ func TestWithInstanceDiscovery(t *testing.T) { mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`))) } mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, authority), "rt", clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, authority), "rt", clientInfo, 3600, 0)), ) client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) if err != nil { @@ -457,7 +457,7 @@ func TestWithCache(t *testing.T) { authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB) mockClient := mock.Client{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), refreshToken, clientInfo, 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), refreshToken, clientInfo, 3600, 0))) client, err := New("client-id", WithAuthority(authorityA), WithCache(&cache), WithHTTPClient(&mockClient)) if err != nil { @@ -505,7 +505,7 @@ func TestWithCache(t *testing.T) { // this should work because the cache contains a refresh token for the user accessToken2 := accessToken + "2" mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantB))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken2, mock.GetIDToken(tenantB, authorityB), refreshToken, clientInfo, 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken2, mock.GetIDToken(tenantB, authorityB), refreshToken, clientInfo, 3600, 0))) ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) if err != nil { t.Fatal(err) @@ -583,7 +583,7 @@ func TestWithClaims(t *testing.T) { mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Federated","cloud_audience_urn":".","cloud_instance_name":".","domain_name":".","federation_protocol":".","federation_metadata_url":"."}`))) } mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) @@ -644,7 +644,7 @@ func TestWithClaims(t *testing.T) { // when given claims, AcquireTokenSilent should request a new access token instead of returning the cached one newToken := "new-access-token" mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { if err := r.ParseForm(); err != nil { t.Fatal(err) @@ -684,7 +684,7 @@ func TestWithPortAuthority(t *testing.T) { mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant))) mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant))) mockClient.AppendResponse( - mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, "issuer"), refreshToken, clientInfo, 3600)), + mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, "issuer"), refreshToken, clientInfo, 3600, 0)), mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }), ) client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient)) @@ -873,7 +873,7 @@ func TestWithAuthenticationScheme(t *testing.T) { name: "interactive", responses: [][]byte{ mock.GetTenantDiscoveryBody(lmo, tenant), - mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600), + mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0), }, }, { @@ -881,7 +881,7 @@ func TestWithAuthenticationScheme(t *testing.T) { responses: [][]byte{ mock.GetTenantDiscoveryBody(lmo, tenant), []byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`), - mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600), + mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0), }, }, } { @@ -957,7 +957,7 @@ func getNewClientWithMockedResponses( if includeAcquireSilentResponses { // we will be testing the AcquireTokenSilent flow after the initial flow, so append the correct responses mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) - mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0))) } client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient)) diff --git a/apps/tests/benchmarks/confidential.go b/apps/tests/benchmarks/confidential.go index 802bafbb..3582682d 100644 --- a/apps/tests/benchmarks/confidential.go +++ b/apps/tests/benchmarks/confidential.go @@ -91,7 +91,7 @@ func populateTokenCache(client base.Client, params testParams) execTime { AccessToken: accessToken, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.FormatInt(int64(i), 10)}}, - }, true) + }) if err != nil { panic(err) } diff --git a/apps/tests/performance/performance_test.go b/apps/tests/performance/performance_test.go index 66050b72..c42ceeb0 100644 --- a/apps/tests/performance/performance_test.go +++ b/apps/tests/performance/performance_test.go @@ -60,7 +60,7 @@ func populateCache(users int, tokens int, authParams authority.AuthParams, clien IDToken: accesstokens.IDToken{ RawToken: "x.e30", }, - }, true) + }) if err != nil { panic(err) } From bd448e83619d4ceb8700fe6c2f67d807183dc618 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 4 Feb 2025 12:37:32 +0000 Subject: [PATCH 11/29] Refactor code --- apps/confidential/confidential_test.go | 5 +--- apps/internal/base/base.go | 33 +++++++++++++------------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 31b7eb6d..1fc8d986 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -849,10 +849,7 @@ func TestRefreshIn(t *testing.T) { func isTimeSame(t time.Time, expectedSeconds int) bool { v := int(time.Until(t).Seconds()) - if v < expectedSeconds-2 || v > expectedSeconds+2 { - return false - } - return true + return !(v < expectedSeconds-2 || v > expectedSeconds+2) } func TestNewCredFromTokenProviderError(t *testing.T) { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index a3725cef..e0ca58fe 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -84,10 +84,9 @@ type AcquireTokenOnBehalfOfParameters struct { // AuthResult contains the results of one token acquisition operation in PublicClientApplication // or ConfidentialClientApplication. For details see https://aka.ms/msal-net-authenticationresult type AuthResult struct { - Account shared.Account - IDToken accesstokens.IDToken - AccessToken string - //RefreshOn indicates the recommended time to request a new access token, or zero if no refresh time is suggested + Account shared.Account + IDToken accesstokens.IDToken + AccessToken string ExpiresOn time.Time GrantedScopes []string DeclinedScopes []string @@ -353,18 +352,20 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { return b.AuthResultFromToken(ctx, authParams, tr) - } else if callErr, ok := er.(*errors.CallErr); ok { - // Check if the error is of type CallErr and matches the relevant status codes - switch callErr.Resp.StatusCode { - case http.StatusRequestTimeout, // 408 - http.StatusTooManyRequests, // 429 - http.StatusInternalServerError, // 500 - http.StatusBadGateway, // 502 - http.StatusServiceUnavailable, // 503 - http.StatusGatewayTimeout: // 504 - default: - // Handle non-retryable errors - return AuthResult{}, er + } else { + var callErr *errors.CallErr + if errors.As(er, &callErr) { + switch callErr.Resp.StatusCode { + case http.StatusRequestTimeout, // 408 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout: // 504 + default: + // return empty token for non handable error + return AuthResult{}, er + } } } } From 90e946ad4b63731995b9674257f13c2cbba812e1 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 7 Feb 2025 18:31:56 +0000 Subject: [PATCH 12/29] Updated the refreshin system on per tenant base --- apps/internal/base/base.go | 44 ++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index e0ca58fe..162d701b 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -11,6 +11,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" @@ -169,6 +170,8 @@ type Client struct { AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex + canRefresh *atomic.Value + refreshMu *sync.Mutex } // Option is an optional argument to the New constructor. @@ -245,7 +248,10 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O cacheAccessorMu: &sync.RWMutex{}, manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), + canRefresh: &atomic.Value{}, + refreshMu: &sync.Mutex{}, } + client.canRefresh.Store(make(map[string]struct{})) for _, o := range options { if err = o(&client); err != nil { break @@ -344,12 +350,12 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err != nil { return ar, err } - // ignore cached access tokens when given claims if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { + if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T, tenant) { + defer b.removeTenantFromCanRefresh(tenant) if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { return b.AuthResultFromToken(ctx, authParams, tr) } else { @@ -483,9 +489,39 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au // was created to test the function against refreshin var GetCurrentTime = time.Now +func (c Client) doesTenantExists(tenant string) bool { + c.refreshMu.Lock() + defer c.refreshMu.Unlock() + canrefreshMap := c.canRefresh.Load().(map[string]struct{}) + _, exists := canrefreshMap[tenant] + return exists +} + +func (c *Client) addTenantIntoCanRefresh(tenant string) { + c.refreshMu.Lock() + defer c.refreshMu.Unlock() + canrefreshMap := c.canRefresh.Load().(map[string]struct{}) + canrefreshMap[tenant] = struct{}{} + c.canRefresh.Store(canrefreshMap) +} + +func (c *Client) removeTenantFromCanRefresh(tenant string) { + c.refreshMu.Lock() + defer c.refreshMu.Unlock() + canrefreshMap := c.canRefresh.Load().(map[string]struct{}) + delete(canrefreshMap, tenant) + c.canRefresh.Store(canrefreshMap) +} + // shouldRefresh returns true if the token should be refreshed. -func shouldRefresh(t time.Time) bool { - return !t.IsZero() && t.Before(GetCurrentTime()) +func (b Client) shouldRefresh(t time.Time, tId string) bool { + if !b.doesTenantExists(tId) { + b.addTenantIntoCanRefresh(tId) + println("success") + return !t.IsZero() && t.Before(GetCurrentTime()) + } + println("fail the exist") + return false } func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { From 0425d56cf1ad2cc335fedacb17978f7c798ca97f Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 11 Feb 2025 14:48:30 +0000 Subject: [PATCH 13/29] Added test for force refresh once for each tenant --- apps/confidential/confidential_test.go | 132 +++++++++++++++++++++++++ apps/internal/base/base.go | 41 ++++---- apps/internal/base/base_test.go | 4 +- 3 files changed, 152 insertions(+), 25 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 1fc8d986..357f1ff0 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -17,7 +17,9 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" + "sync" "testing" "time" @@ -764,6 +766,136 @@ func TestNewCredFromTokenProvider(t *testing.T) { } } +func TestRefreshInMultipleRequests(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + firstToken := "first token" + secondToken := "new token" + lmo := "login.microsoftonline.com" + refreshIn := 43200 + expiresIn := 86400 + + t.Run("Test for refresh multiple request", func(t *testing.T) { + originalTime := base.GetCurrentTime + defer func() { + base.GetCurrentTime = originalTime + }() + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + // Create the client instance + client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + // Acquire the first token for first tenant + ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } + + // Acquire the first token for second tenant + arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if arSecond.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken) + } + + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) + base.GetCurrentTime = func() time.Time { + return fixedTime + } + var wg sync.WaitGroup + type tokenResult struct { + Token string + Tenant string + } + ch := make(chan tokenResult, 10) + var mu sync.Mutex // Mutex to protect access to expectedResponse + expectedResponse := []tokenResult{ + {Token: "new token", Tenant: "firstTentant"}, + {Token: "new token", Tenant: "firstTentant"}, + {Token: "first token", Tenant: "secondTentant"}, + {Token: "first token", Tenant: "secondTentant"}, + {Token: "first token", Tenant: "secondTentant"}, + {Token: "first token", Tenant: "firstTentant"}, + {Token: "first token", Tenant: "secondTentant"}, + {Token: "new token", Tenant: "firstTentant"}, + {Token: "new token", Tenant: "firstTentant"}, + {Token: "first token", Tenant: "secondTentant"}, + } + gotResponse := []tokenResult{} + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + time.Sleep(150 * time.Millisecond) + }), + ) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + time.Sleep(100 * time.Millisecond) + mu.Lock() + base.GetCurrentTime = originalTime + mu.Unlock() + }), + ) + for i := 0; i < 5; i++ { + wg.Add(1) + wg.Add(1) + time.Sleep(50 * time.Millisecond) + go func() { + defer wg.Done() + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Error(err) + return + } + ch <- tokenResult{Token: ar.AccessToken, Tenant: "firstTentant"} // Send result to channel + }() + go func() { + time.Sleep(50 * time.Millisecond) + defer wg.Done() + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Error(err) + return + } + ch <- tokenResult{Token: ar.AccessToken, Tenant: "secondTentant"} // Send result to channel + }() + } + // Waiting for all goroutines to finish + go func() { + for s := range ch { + mu.Lock() // Acquire lock before modifying expectedResponse + gotResponse = append(gotResponse, s) + println(s.Token, s.Tenant) + mu.Unlock() // Release lock after modifying expectedResponse + } + if reflect.DeepEqual(gotResponse, expectedResponse) { + t.Error("gotResponse and expectedResponse are not equal") + } + }() + wg.Wait() + close(ch) + }) + +} + func TestRefreshIn(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 162d701b..77023c96 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -11,7 +11,6 @@ import ( "reflect" "strings" "sync" - "sync/atomic" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" @@ -170,8 +169,8 @@ type Client struct { AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex - canRefresh *atomic.Value - refreshMu *sync.Mutex + canRefresh map[string]struct{} + refreshMu *sync.RWMutex } // Option is an optional argument to the New constructor. @@ -248,10 +247,9 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O cacheAccessorMu: &sync.RWMutex{}, manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), - canRefresh: &atomic.Value{}, - refreshMu: &sync.Mutex{}, + canRefresh: make(map[string]struct{}), + refreshMu: &sync.RWMutex{}, } - client.canRefresh.Store(make(map[string]struct{})) for _, o := range options { if err = o(&client); err != nil { break @@ -350,12 +348,16 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if err != nil { return ar, err } + + // go routine call 100 times // ignore cached access tokens when given claims if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T, tenant) { - defer b.removeTenantFromCanRefresh(tenant) + b.refreshMu.Lock() + b.removeTenantFromCanRefresh(tenant) + defer b.refreshMu.Unlock() if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { return b.AuthResultFromToken(ctx, authParams, tr) } else { @@ -374,6 +376,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen } } } + } ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) return ar, err @@ -490,37 +493,29 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au var GetCurrentTime = time.Now func (c Client) doesTenantExists(tenant string) bool { - c.refreshMu.Lock() - defer c.refreshMu.Unlock() - canrefreshMap := c.canRefresh.Load().(map[string]struct{}) - _, exists := canrefreshMap[tenant] + _, exists := c.canRefresh[tenant] return exists } func (c *Client) addTenantIntoCanRefresh(tenant string) { - c.refreshMu.Lock() - defer c.refreshMu.Unlock() - canrefreshMap := c.canRefresh.Load().(map[string]struct{}) - canrefreshMap[tenant] = struct{}{} - c.canRefresh.Store(canrefreshMap) + c.canRefresh[tenant] = struct{}{} } func (c *Client) removeTenantFromCanRefresh(tenant string) { - c.refreshMu.Lock() - defer c.refreshMu.Unlock() - canrefreshMap := c.canRefresh.Load().(map[string]struct{}) - delete(canrefreshMap, tenant) - c.canRefresh.Store(canrefreshMap) + delete(c.canRefresh, tenant) } // shouldRefresh returns true if the token should be refreshed. func (b Client) shouldRefresh(t time.Time, tId string) bool { + b.refreshMu.RLock() if !b.doesTenantExists(tId) { + b.refreshMu.RUnlock() + b.refreshMu.Lock() + defer b.refreshMu.Unlock() b.addTenantIntoCanRefresh(tId) - println("success") return !t.IsZero() && t.Before(GetCurrentTime()) } - println("fail the exist") + b.refreshMu.RUnlock() return false } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 69becd23..d25cffa5 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -448,7 +448,7 @@ func TestAuthResultFromStorage(t *testing.T) { func TestShouldRefresh(t *testing.T) { // Get the current time to use for comparison now := time.Now() - + client := fakeClient(t) tests := []struct { name string input time.Time @@ -473,7 +473,7 @@ func TestShouldRefresh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := shouldRefresh(tt.input) + result := client.shouldRefresh(tt.input, tt.name) if result != tt.expected { t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) } From d82d813d74bb2480ab6676119f986093136d7748 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 11 Feb 2025 14:49:54 +0000 Subject: [PATCH 14/29] Update confidential_test.go --- apps/confidential/confidential_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 357f1ff0..b7d85168 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -883,7 +883,6 @@ func TestRefreshInMultipleRequests(t *testing.T) { for s := range ch { mu.Lock() // Acquire lock before modifying expectedResponse gotResponse = append(gotResponse, s) - println(s.Token, s.Tenant) mu.Unlock() // Release lock after modifying expectedResponse } if reflect.DeepEqual(gotResponse, expectedResponse) { From d817f7d3738baf0700a3ad53c942951fcb68bf21 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 11 Feb 2025 15:26:02 +0000 Subject: [PATCH 15/29] Update confidential_test.go --- apps/confidential/confidential_test.go | 42 ++++++++++++++------------ 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index b7d85168..45718b2e 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -17,7 +17,6 @@ import ( "net/url" "os" "path/filepath" - "reflect" "strings" "sync" "testing" @@ -826,35 +825,26 @@ func TestRefreshInMultipleRequests(t *testing.T) { Token string Tenant string } - ch := make(chan tokenResult, 10) + tokenOneCchecker := false + tokenTwoCchecker := false + + ch := make(chan tokenResult, 14) var mu sync.Mutex // Mutex to protect access to expectedResponse - expectedResponse := []tokenResult{ - {Token: "new token", Tenant: "firstTentant"}, - {Token: "new token", Tenant: "firstTentant"}, - {Token: "first token", Tenant: "secondTentant"}, - {Token: "first token", Tenant: "secondTentant"}, - {Token: "first token", Tenant: "secondTentant"}, - {Token: "first token", Tenant: "firstTentant"}, - {Token: "first token", Tenant: "secondTentant"}, - {Token: "new token", Tenant: "firstTentant"}, - {Token: "new token", Tenant: "firstTentant"}, - {Token: "first token", Tenant: "secondTentant"}, - } gotResponse := []tokenResult{} mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { time.Sleep(150 * time.Millisecond) }), ) mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { time.Sleep(100 * time.Millisecond) mu.Lock() base.GetCurrentTime = originalTime mu.Unlock() }), ) - for i := 0; i < 5; i++ { + for i := 0; i < 7; i++ { wg.Add(1) wg.Add(1) time.Sleep(50 * time.Millisecond) @@ -865,6 +855,13 @@ func TestRefreshInMultipleRequests(t *testing.T) { t.Error(err) return } + if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { + if tokenOneCchecker { + t.Error("Error can only call this once") + } else { + tokenOneCchecker = true + } + } ch <- tokenResult{Token: ar.AccessToken, Tenant: "firstTentant"} // Send result to channel }() go func() { @@ -875,6 +872,13 @@ func TestRefreshInMultipleRequests(t *testing.T) { t.Error(err) return } + if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { + if tokenTwoCchecker { + t.Error("Error can only call this once") + } else { + tokenTwoCchecker = true + } + } ch <- tokenResult{Token: ar.AccessToken, Tenant: "secondTentant"} // Send result to channel }() } @@ -885,8 +889,8 @@ func TestRefreshInMultipleRequests(t *testing.T) { gotResponse = append(gotResponse, s) mu.Unlock() // Release lock after modifying expectedResponse } - if reflect.DeepEqual(gotResponse, expectedResponse) { - t.Error("gotResponse and expectedResponse are not equal") + if !tokenOneCchecker && !tokenTwoCchecker { + t.Error("Error should be called at least once") } }() wg.Wait() From b955ede0f0dfdeeabbe20f12282f87f81094c198 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 11 Feb 2025 15:41:27 +0000 Subject: [PATCH 16/29] Update confidential_test.go --- apps/confidential/confidential_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 45718b2e..d38875df 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -845,8 +845,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { }), ) for i := 0; i < 7; i++ { - wg.Add(1) - wg.Add(1) + wg.Add(2) time.Sleep(50 * time.Millisecond) go func() { defer wg.Done() From c5393f5377e7bf2c5aa311928b383048d7f70f24 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 11 Feb 2025 15:55:57 +0000 Subject: [PATCH 17/29] Refactor code --- apps/confidential/confidential_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index d38875df..0fc6c455 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -825,8 +825,8 @@ func TestRefreshInMultipleRequests(t *testing.T) { Token string Tenant string } - tokenOneCchecker := false - tokenTwoCchecker := false + firstTenantChecker := false + secondTenantChecker := false ch := make(chan tokenResult, 14) var mu sync.Mutex // Mutex to protect access to expectedResponse @@ -855,10 +855,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { return } if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if tokenOneCchecker { + if firstTenantChecker { t.Error("Error can only call this once") } else { - tokenOneCchecker = true + firstTenantChecker = true } } ch <- tokenResult{Token: ar.AccessToken, Tenant: "firstTentant"} // Send result to channel @@ -872,10 +872,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { return } if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if tokenTwoCchecker { + if secondTenantChecker { t.Error("Error can only call this once") } else { - tokenTwoCchecker = true + secondTenantChecker = true } } ch <- tokenResult{Token: ar.AccessToken, Tenant: "secondTentant"} // Send result to channel @@ -888,7 +888,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { gotResponse = append(gotResponse, s) mu.Unlock() // Release lock after modifying expectedResponse } - if !tokenOneCchecker && !tokenTwoCchecker { + if !firstTenantChecker && !secondTenantChecker { t.Error("Error should be called at least once") } }() From 2700b6457d09f8bddaa737bd2eb4286dffc12810 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 13 Feb 2025 18:10:49 +0000 Subject: [PATCH 18/29] Update confidential_test.go --- apps/confidential/confidential_test.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 0fc6c455..738d10a1 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -828,25 +828,22 @@ func TestRefreshInMultipleRequests(t *testing.T) { firstTenantChecker := false secondTenantChecker := false - ch := make(chan tokenResult, 14) + ch := make(chan tokenResult, 10000) var mu sync.Mutex // Mutex to protect access to expectedResponse gotResponse := []tokenResult{} mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { - time.Sleep(150 * time.Millisecond) }), ) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { - time.Sleep(100 * time.Millisecond) mu.Lock() base.GetCurrentTime = originalTime mu.Unlock() }), ) - for i := 0; i < 7; i++ { + for i := 0; i < 10000; i++ { wg.Add(2) - time.Sleep(50 * time.Millisecond) go func() { defer wg.Done() ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) @@ -864,7 +861,6 @@ func TestRefreshInMultipleRequests(t *testing.T) { ch <- tokenResult{Token: ar.AccessToken, Tenant: "firstTentant"} // Send result to channel }() go func() { - time.Sleep(50 * time.Millisecond) defer wg.Done() ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) if err != nil { From e35eb9d5e7cdd97aa58164ac9cc2d414ba7891d0 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Mon, 17 Feb 2025 15:52:29 +0000 Subject: [PATCH 19/29] Updated some tests to adapt to change in time --- apps/internal/oauth/ops/accesstokens/accesstokens_test.go | 6 +++--- apps/managedidentity/managedidentity_test.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go index a0c18b45..a38a0f24 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens_test.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens_test.go @@ -769,7 +769,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: time.Unix(86400, 0), + ExpiresOn: time.Now().Add(time.Hour * 24), RefreshOn: internalTime.DurationTime{T: time.Unix(43200, 0)}, ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, @@ -793,7 +793,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: time.Unix(86400, 0), + ExpiresOn: time.Now().Add(time.Hour * 24), RefreshOn: internalTime.DurationTime{T: time.Unix(43199, 0)}, ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, @@ -816,7 +816,7 @@ func TestTokenResponseUnmarshal(t *testing.T) { }`, want: TokenResponse{ AccessToken: "secret", - ExpiresOn: time.Unix(86400, 0), + ExpiresOn: time.Now().Add(time.Hour * 24), ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86400, 0)}, GrantedScopes: Scopes{Slice: []string{"openid", "profile"}}, ClientInfo: ClientInfo{ diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index ce9e8cfc..b5565efa 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -347,7 +347,7 @@ func TestCacheScopes(t *testing.T) { } for _, r := range []string{"A", "B/.default"} { - mc.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(r, "", "", "", 3600))) + mc.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(r, "", "", "", 3600, 3600))) for i := 0; i < 2; i++ { ar, err := client.AcquireToken(context.Background(), r) if err != nil { From 1cea155bfc5a649fce0fa1e8d3624b91fb9d60b7 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 19 Feb 2025 16:27:14 +0000 Subject: [PATCH 20/29] Added RefreshIn logic for Managed Identity --- apps/confidential/confidential_test.go | 2 +- apps/internal/base/base.go | 18 ---- apps/managedidentity/managedidentity.go | 26 ++++++ apps/managedidentity/managedidentity_test.go | 90 ++++++++++++++++++++ 4 files changed, 117 insertions(+), 19 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index dfbf9132..f09c9981 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -879,7 +879,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { secondTenantChecker := false ch := make(chan tokenResult, 10000) - var mu sync.Mutex // Mutex to protect access to expectedResponse + var mu sync.Mutex gotResponse := []tokenResult{} mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index a34c9bf9..b1c35357 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -6,7 +6,6 @@ package base import ( "context" "fmt" - "net/http" "net/url" "reflect" "strings" @@ -348,7 +347,6 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen return ar, err } - // go routine call 100 times // ignore cached access tokens when given claims if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) @@ -359,23 +357,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen defer b.refreshMu.Unlock() if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { return b.AuthResultFromToken(ctx, authParams, tr) - } else { - var callErr *errors.CallErr - if errors.As(er, &callErr) { - switch callErr.Resp.StatusCode { - case http.StatusRequestTimeout, // 408 - http.StatusTooManyRequests, // 429 - http.StatusInternalServerError, // 500 - http.StatusBadGateway, // 502 - http.StatusServiceUnavailable, // 503 - http.StatusGatewayTimeout: // 504 - default: - // return empty token for non handable error - return AuthResult{}, er - } - } } - } ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) return ar, err diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 4e274869..a48e851c 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -20,6 +20,7 @@ import ( "path/filepath" "runtime" "strings" + "sync/atomic" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" @@ -165,6 +166,7 @@ type Client struct { source Source authParams authority.AuthParams retryPolicyEnabled bool + canRefresh *atomic.Int32 } type AcquireTokenOptions struct { @@ -247,11 +249,13 @@ func New(id ID, options ...ClientOption) (Client, error) { default: return Client{}, fmt.Errorf("unsupported type %T", id) } + var zero atomic.Int32 client := Client{ miType: id, httpClient: shared.DefaultClient, retryPolicyEnabled: true, source: source, + canRefresh: &zero, } for _, option := range options { option(&client) @@ -291,6 +295,18 @@ func GetSource() (Source, error) { return DefaultToIMDS, nil } +// This function wraps time.Now() and is used for refreshing the application +// was created to test the function against refreshin +var GetCurrentTime = time.Now + +// shouldRefresh returns true if the token should be refreshed. +func (b Client) shouldRefresh(t time.Time) bool { + if b.canRefresh.CompareAndSwap(0, 1) { + return !t.IsZero() && t.Before(GetCurrentTime()) + } + return false +} + // Acquires tokens from the configured managed identity on an azure resource. // // Resource: scopes application is requesting access to @@ -311,10 +327,20 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac } ar, err := base.AuthResultFromStorage(storageTokenResponse) if err == nil { + if c.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { + defer c.canRefresh.Store(1) + if tr, er := c.getToken(ctx, resource); er == nil { + return tr, nil + } + } ar.AccessToken, err = c.authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) return ar, err } } + return c.getToken(ctx, resource) +} + +func (c Client) getToken(ctx context.Context, resource string) (base.AuthResult, error) { switch c.source { case AzureArc: return c.acquireTokenForAzureArc(ctx, resource) diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index b5565efa..974ced44 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" "time" @@ -1131,3 +1132,92 @@ func TestCreatingIMDSClient(t *testing.T) { }) } } + +func TestRefreshInMultipleRequests(t *testing.T) { + firstToken := "first token" + secondToken := "new token" + refreshIn := 43200 + expiresIn := 86400 + resource := "https://resource/.default" + miType := SystemAssigned() + setEnvVars(t, CloudShell) + + t.Run("Test for refresh multiple request", func(t *testing.T) { + originalTime := GetCurrentTime + defer func() { + GetCurrentTime = originalTime + }() + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + // Create the client instance + client, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + ar, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) + GetCurrentTime = func() time.Time { + return fixedTime + } + var wg sync.WaitGroup + + requestChecker := false + + ch := make(chan string, 10000) + var mu sync.Mutex // Mutex to protect access to expectedResponse + gotResponse := []string{} + mockClient.AppendResponse( + mock.WithCallback(func(*http.Request) { + GetCurrentTime = originalTime + }), + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + }), + ) + + for i := 0; i < 10000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + ar, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Error(err) + return + } + if ar.AccessToken == secondToken && ar.Metadata.TokenSource == base.IdentityProvider { + if requestChecker { + t.Error("Error can only call this only once") + } else { + requestChecker = true + } + } + ch <- ar.AccessToken + }() + } + // Waiting for all goroutines to finish + go func() { + for s := range ch { + mu.Lock() // Acquire lock before modifying expectedResponse + gotResponse = append(gotResponse, s) + mu.Unlock() // Release lock after modifying expectedResponse + } + if !requestChecker { + t.Error("Error should be called at least once") + } + }() + wg.Wait() + close(ch) + }) +} From e6a3b29a19705c3e1fc69cf285e3a95d3f9f4f0b Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 20 Feb 2025 19:05:42 +0000 Subject: [PATCH 21/29] Added a sync http client and updated tests --- apps/confidential/confidential_test.go | 32 +++------- apps/internal/base/base.go | 49 ++++++--------- apps/internal/base/base_test.go | 2 +- apps/internal/mock/syncmock.go | 49 +++++++++++++++ apps/managedidentity/managedidentity.go | 18 +++--- apps/managedidentity/managedidentity_test.go | 66 ++++++++++++++------ 6 files changed, 132 insertions(+), 84 deletions(-) create mode 100644 apps/internal/mock/syncmock.go diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index f09c9981..e3a643a7 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -815,6 +815,7 @@ func TestNewCredFromTokenProvider(t *testing.T) { } } +//  go test -race -timeout 30s -run ^TestRefreshInMultipleRequests$ github.com/AzureAD/microsoft-authentication-library-for-go/apps/confideintial func TestRefreshInMultipleRequests(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { @@ -832,7 +833,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { base.GetCurrentTime = originalTime }() // Create a mock client and append mock responses - mockClient := mock.Client{} + mockClient := mock.SyncClient{} mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant"))) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), @@ -871,27 +872,15 @@ func TestRefreshInMultipleRequests(t *testing.T) { return fixedTime } var wg sync.WaitGroup - type tokenResult struct { - Token string - Tenant string - } firstTenantChecker := false secondTenantChecker := false - ch := make(chan tokenResult, 10000) - var mu sync.Mutex - gotResponse := []tokenResult{} mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { - }), + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), ) mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { - mu.Lock() - base.GetCurrentTime = originalTime - mu.Unlock() - }), - ) + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200)))) + for i := 0; i < 10000; i++ { wg.Add(2) go func() { @@ -908,7 +897,6 @@ func TestRefreshInMultipleRequests(t *testing.T) { firstTenantChecker = true } } - ch <- tokenResult{Token: ar.AccessToken, Tenant: "firstTentant"} // Send result to channel }() go func() { defer wg.Done() @@ -924,22 +912,16 @@ func TestRefreshInMultipleRequests(t *testing.T) { secondTenantChecker = true } } - ch <- tokenResult{Token: ar.AccessToken, Tenant: "secondTentant"} // Send result to channel }() } // Waiting for all goroutines to finish go func() { - for s := range ch { - mu.Lock() // Acquire lock before modifying expectedResponse - gotResponse = append(gotResponse, s) - mu.Unlock() // Release lock after modifying expectedResponse - } - if !firstTenantChecker && !secondTenantChecker { + wg.Wait() + if !secondTenantChecker && !firstTenantChecker { t.Error("Error should be called at least once") } }() wg.Wait() - close(ch) }) } diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index b1c35357..09f6174c 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" @@ -167,7 +168,7 @@ type Client struct { AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New(). cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex - canRefresh map[string]struct{} + canRefresh map[string]*atomic.Value refreshMu *sync.RWMutex } @@ -245,7 +246,7 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O cacheAccessorMu: &sync.RWMutex{}, manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), - canRefresh: make(map[string]struct{}), + canRefresh: make(map[string]*atomic.Value), refreshMu: &sync.RWMutex{}, } for _, o := range options { @@ -351,12 +352,20 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T, tenant) { + if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { b.refreshMu.Lock() - b.removeTenantFromCanRefresh(tenant) - defer b.refreshMu.Unlock() - if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { - return b.AuthResultFromToken(ctx, authParams, tr) + if _, exists := b.canRefresh[tenant]; !exists { + var empty atomic.Value + empty.Store(false) + b.canRefresh[tenant] = &empty + } + refreshValue := b.canRefresh[tenant] + b.refreshMu.Unlock() + if refreshValue.CompareAndSwap(false, true) { + defer refreshValue.Store(false) + if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { + return b.AuthResultFromToken(ctx, authParams, tr) + } } } ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken) @@ -473,31 +482,9 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au // was created to test the function against refreshin var GetCurrentTime = time.Now -func (c Client) doesTenantExists(tenant string) bool { - _, exists := c.canRefresh[tenant] - return exists -} - -func (c *Client) addTenantIntoCanRefresh(tenant string) { - c.canRefresh[tenant] = struct{}{} -} - -func (c *Client) removeTenantFromCanRefresh(tenant string) { - delete(c.canRefresh, tenant) -} - // shouldRefresh returns true if the token should be refreshed. -func (b Client) shouldRefresh(t time.Time, tId string) bool { - b.refreshMu.RLock() - if !b.doesTenantExists(tId) { - b.refreshMu.RUnlock() - b.refreshMu.Lock() - defer b.refreshMu.Unlock() - b.addTenantIntoCanRefresh(tId) - return !t.IsZero() && t.Before(GetCurrentTime()) - } - b.refreshMu.RUnlock() - return false +func (b *Client) shouldRefresh(t time.Time) bool { + return !t.IsZero() && t.Before(GetCurrentTime()) } func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index c0eacb9d..3fd52957 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -473,7 +473,7 @@ func TestShouldRefresh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := client.shouldRefresh(tt.input, tt.name) + result := client.shouldRefresh(tt.input) if result != tt.expected { t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) } diff --git a/apps/internal/mock/syncmock.go b/apps/internal/mock/syncmock.go new file mode 100644 index 00000000..ec8d0360 --- /dev/null +++ b/apps/internal/mock/syncmock.go @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package mock + +import ( + "bytes" + "fmt" + "io" + "net/http" + "sync" +) + +// Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. +type SyncClient struct { + mu sync.Mutex + resp []response +} + +func (c *SyncClient) AppendResponse(opts ...responseOption) { + c.mu.Lock() + defer c.mu.Unlock() + + r := response{code: http.StatusOK, headers: http.Header{}} + for _, o := range opts { + o.apply(&r) + } + c.resp = append(c.resp, r) +} + +func (c *SyncClient) Do(req *http.Request) (*http.Response, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if len(c.resp) == 0 { + panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) + } + resp := c.resp[0] + c.resp = c.resp[1:] + if resp.callback != nil { + resp.callback(req) + } + res := http.Response{Header: resp.headers, StatusCode: resp.code} + res.Body = io.NopCloser(bytes.NewReader(resp.body)) + return &res, nil +} + +// CloseIdleConnections implements the comm.HTTPClient interface +func (*SyncClient) CloseIdleConnections() {} diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index a48e851c..a2a563d4 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -166,7 +166,7 @@ type Client struct { source Source authParams authority.AuthParams retryPolicyEnabled bool - canRefresh *atomic.Int32 + canRefresh *atomic.Value } type AcquireTokenOptions struct { @@ -249,7 +249,8 @@ func New(id ID, options ...ClientOption) (Client, error) { default: return Client{}, fmt.Errorf("unsupported type %T", id) } - var zero atomic.Int32 + var zero atomic.Value = atomic.Value{} + zero.Store(false) client := Client{ miType: id, httpClient: shared.DefaultClient, @@ -297,14 +298,17 @@ func GetSource() (Source, error) { // This function wraps time.Now() and is used for refreshing the application // was created to test the function against refreshin -var GetCurrentTime = time.Now +var getCurrentTime = time.Now // shouldRefresh returns true if the token should be refreshed. func (b Client) shouldRefresh(t time.Time) bool { - if b.canRefresh.CompareAndSwap(0, 1) { - return !t.IsZero() && t.Before(GetCurrentTime()) + if t.IsZero() || t.After(getCurrentTime()) { + return false } - return false + if !b.canRefresh.CompareAndSwap(false, true) { + return false + } + return true } // Acquires tokens from the configured managed identity on an azure resource. @@ -328,7 +332,7 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac ar, err := base.AuthResultFromStorage(storageTokenResponse) if err == nil { if c.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { - defer c.canRefresh.Store(1) + defer c.canRefresh.Store(false) if tr, er := c.getToken(ctx, resource); er == nil { return tr, nil } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 974ced44..559f1294 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1143,15 +1143,15 @@ func TestRefreshInMultipleRequests(t *testing.T) { setEnvVars(t, CloudShell) t.Run("Test for refresh multiple request", func(t *testing.T) { - originalTime := GetCurrentTime + originalTime := getCurrentTime defer func() { - GetCurrentTime = originalTime + getCurrentTime = originalTime }() before := cacheManager defer func() { cacheManager = before }() cacheManager = storage.New(nil) // Create a mock client and append mock responses - mockClient := mock.Client{} + mockClient := mock.SyncClient{} mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), ) @@ -1168,25 +1168,18 @@ func TestRefreshInMultipleRequests(t *testing.T) { if ar.AccessToken != firstToken { t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) } + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) - GetCurrentTime = func() time.Time { + getCurrentTime = func() time.Time { return fixedTime } var wg sync.WaitGroup - requestChecker := false - ch := make(chan string, 10000) - var mu sync.Mutex // Mutex to protect access to expectedResponse - gotResponse := []string{} mockClient.AppendResponse( - mock.WithCallback(func(*http.Request) { - GetCurrentTime = originalTime - }), - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))), mock.WithCallback(func(req *http.Request) { + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn+43200))), mock.WithCallback(func(req *http.Request) { }), ) - for i := 0; i < 10000; i++ { wg.Add(1) go func() { @@ -1203,21 +1196,54 @@ func TestRefreshInMultipleRequests(t *testing.T) { requestChecker = true } } - ch <- ar.AccessToken }() } // Waiting for all goroutines to finish go func() { - for s := range ch { - mu.Lock() // Acquire lock before modifying expectedResponse - gotResponse = append(gotResponse, s) - mu.Unlock() // Release lock after modifying expectedResponse - } + wg.Wait() if !requestChecker { t.Error("Error should be called at least once") } }() wg.Wait() - close(ch) }) } + +func TestShouldRefresh(t *testing.T) { + // Get the current time to use for comparison + now := time.Now() + tests := []struct { + name string + input time.Time + expected bool + }{ + { + name: "Zero time", + input: time.Time{}, // Zero time + expected: false, // Should return false because it's zero time + }, + { + name: "Future time", + input: now.Add(time.Hour), // 1 hour in the future + expected: false, // Should return false because it's in the future + }, + { + name: "Past time", + input: now.Add(-time.Hour), // 1 hour in the past + expected: true, // Should return true because it's in the past + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := New(SystemAssigned()) + if err != nil { + t.Fatal(err) + } + result := client.shouldRefresh(tt.input) + if result != tt.expected { + t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) + } + }) + } +} From 76218c5c211eb3c06f71ca6f2d25b4a5864401fc Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 21 Feb 2025 11:30:57 +0000 Subject: [PATCH 22/29] Updated the code --- apps/confidential/confidential_test.go | 178 +++++++++---------- apps/internal/base/base.go | 18 +- apps/internal/base/base_test.go | 3 +- apps/internal/mock/mock.go | 6 + apps/internal/mock/syncmock.go | 49 ----- apps/managedidentity/managedidentity.go | 2 +- apps/managedidentity/managedidentity_test.go | 117 ++++++------ 7 files changed, 162 insertions(+), 211 deletions(-) delete mode 100644 apps/internal/mock/syncmock.go diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index e3a643a7..6298dc92 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -827,102 +827,102 @@ func TestRefreshInMultipleRequests(t *testing.T) { refreshIn := 43200 expiresIn := 86400 - t.Run("Test for refresh multiple request", func(t *testing.T) { - originalTime := base.GetCurrentTime - defer func() { - base.GetCurrentTime = originalTime - }() - // Create a mock client and append mock responses - mockClient := mock.SyncClient{} - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant"))) - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), - ) - mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant"))) - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), - ) - // Create the client instance - client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) - if err != nil { - t.Fatal(err) - } - // Acquire the first token for first tenant - ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant")) - if err != nil { - t.Fatal(err) - } - // Assert the first token is returned - if ar.AccessToken != firstToken { - t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) - } + originalTime := base.GetCurrentTime + defer func() { + base.GetCurrentTime = originalTime + }() + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + // Create the client instance + client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + // Acquire the first token for first tenant + ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } + // Acquire the first token for second tenant + arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Fatal(err) + } + if arSecond.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken) + } + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) + base.GetCurrentTime = func() time.Time { + return fixedTime + } + var wg sync.WaitGroup + done := make(chan struct{}) - // Acquire the first token for second tenant - arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant")) - if err != nil { - t.Fatal(err) - } - // Assert the first token is returned - if arSecond.AccessToken != firstToken { - t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken) - } + firstTenantChecker := false + secondTenantChecker := false + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), + ) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200)))) - fixedTime := time.Now().Add(time.Duration(43400) * time.Second) - base.GetCurrentTime = func() time.Time { - return fixedTime - } - var wg sync.WaitGroup - firstTenantChecker := false - secondTenantChecker := false - - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), - ) - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200)))) - - for i := 0; i < 10000; i++ { - wg.Add(2) - go func() { - defer wg.Done() - ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) - if err != nil { - t.Error(err) - return - } - if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if firstTenantChecker { - t.Error("Error can only call this once") - } else { - firstTenantChecker = true - } - } - }() - go func() { - defer wg.Done() - ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) - if err != nil { - t.Error(err) - return - } - if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if secondTenantChecker { - t.Error("Error can only call this once") - } else { - secondTenantChecker = true - } + for i := 0; i < 10000; i++ { + wg.Add(2) + go func() { + defer wg.Done() + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Error(err) + return + } + if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { + if firstTenantChecker { + t.Error("Error can only call this once") + } else { + firstTenantChecker = true } - }() - } - // Waiting for all goroutines to finish + } + }() go func() { - wg.Wait() - if !secondTenantChecker && !firstTenantChecker { - t.Error("Error should be called at least once") + defer wg.Done() + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Error(err) + return + } + if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { + if secondTenantChecker { + t.Error("Error can only call this once") + } else { + secondTenantChecker = true + } } }() + } + // Wait for all goroutines in a separate goroutine + go func() { wg.Wait() - }) + close(done) + }() + + // Wait for all goroutines to complete + <-done + + if !secondTenantChecker && !firstTenantChecker { + t.Error("Error should be called at least once") + } } diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 09f6174c..9ae76870 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -169,7 +169,7 @@ type Client struct { cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex canRefresh map[string]*atomic.Value - refreshMu *sync.RWMutex + refreshMu *sync.Mutex } // Option is an optional argument to the New constructor. @@ -247,7 +247,7 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), canRefresh: make(map[string]*atomic.Value), - refreshMu: &sync.RWMutex{}, + refreshMu: &sync.Mutex{}, } for _, o := range options { if err = o(&client); err != nil { @@ -352,14 +352,14 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { + if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { b.refreshMu.Lock() - if _, exists := b.canRefresh[tenant]; !exists { - var empty atomic.Value - empty.Store(false) - b.canRefresh[tenant] = &empty + refreshValue, exists := b.canRefresh[tenant] + if !exists { + refreshValue = &atomic.Value{} + refreshValue.Store(false) + b.canRefresh[tenant] = refreshValue } - refreshValue := b.canRefresh[tenant] b.refreshMu.Unlock() if refreshValue.CompareAndSwap(false, true) { defer refreshValue.Store(false) @@ -483,7 +483,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au var GetCurrentTime = time.Now // shouldRefresh returns true if the token should be refreshed. -func (b *Client) shouldRefresh(t time.Time) bool { +func shouldRefresh(t time.Time) bool { return !t.IsZero() && t.Before(GetCurrentTime()) } diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index 3fd52957..e07ea110 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -448,7 +448,6 @@ func TestAuthResultFromStorage(t *testing.T) { func TestShouldRefresh(t *testing.T) { // Get the current time to use for comparison now := time.Now() - client := fakeClient(t) tests := []struct { name string input time.Time @@ -473,7 +472,7 @@ func TestShouldRefresh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := client.shouldRefresh(tt.input) + result := shouldRefresh(tt.input) if result != tt.expected { t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) } diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index 508ca579..a23f2a27 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" @@ -62,10 +63,13 @@ func WithHTTPStatusCode(statusCode int) responseOption { // Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. type Client struct { + mu sync.Mutex resp []response } func (c *Client) AppendResponse(opts ...responseOption) { + c.mu.Lock() + defer c.mu.Unlock() r := response{code: http.StatusOK, headers: http.Header{}} for _, o := range opts { o.apply(&r) @@ -74,6 +78,8 @@ func (c *Client) AppendResponse(opts ...responseOption) { } func (c *Client) Do(req *http.Request) (*http.Response, error) { + c.mu.Lock() + defer c.mu.Unlock() if len(c.resp) == 0 { panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) } diff --git a/apps/internal/mock/syncmock.go b/apps/internal/mock/syncmock.go deleted file mode 100644 index ec8d0360..00000000 --- a/apps/internal/mock/syncmock.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package mock - -import ( - "bytes" - "fmt" - "io" - "net/http" - "sync" -) - -// Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence. -type SyncClient struct { - mu sync.Mutex - resp []response -} - -func (c *SyncClient) AppendResponse(opts ...responseOption) { - c.mu.Lock() - defer c.mu.Unlock() - - r := response{code: http.StatusOK, headers: http.Header{}} - for _, o := range opts { - o.apply(&r) - } - c.resp = append(c.resp, r) -} - -func (c *SyncClient) Do(req *http.Request) (*http.Response, error) { - c.mu.Lock() - defer c.mu.Unlock() - - if len(c.resp) == 0 { - panic(fmt.Sprintf(`no response for "%s"`, req.URL.String())) - } - resp := c.resp[0] - c.resp = c.resp[1:] - if resp.callback != nil { - resp.callback(req) - } - res := http.Response{Header: resp.headers, StatusCode: resp.code} - res.Body = io.NopCloser(bytes.NewReader(resp.body)) - return &res, nil -} - -// CloseIdleConnections implements the comm.HTTPClient interface -func (*SyncClient) CloseIdleConnections() {} diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index a2a563d4..f75c9183 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -249,7 +249,7 @@ func New(id ID, options ...ClientOption) (Client, error) { default: return Client{}, fmt.Errorf("unsupported type %T", id) } - var zero atomic.Value = atomic.Value{} + zero := atomic.Value{} zero.Store(false) client := Client{ miType: id, diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 559f1294..07957e50 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1142,71 +1142,66 @@ func TestRefreshInMultipleRequests(t *testing.T) { miType := SystemAssigned() setEnvVars(t, CloudShell) - t.Run("Test for refresh multiple request", func(t *testing.T) { - originalTime := getCurrentTime - defer func() { - getCurrentTime = originalTime - }() - before := cacheManager - defer func() { cacheManager = before }() - cacheManager = storage.New(nil) - // Create a mock client and append mock responses - mockClient := mock.SyncClient{} - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), - ) - // Create the client instance - client, err := New(miType, WithHTTPClient(&mockClient)) - if err != nil { - t.Fatal(err) - } - ar, err := client.AcquireToken(context.Background(), resource) - if err != nil { - t.Fatal(err) - } - // Assert the first token is returned - if ar.AccessToken != firstToken { - t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) - } + originalTime := getCurrentTime + defer func() { + getCurrentTime = originalTime + }() + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + // Create the client instance + client, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + ar, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } - fixedTime := time.Now().Add(time.Duration(43400) * time.Second) - getCurrentTime = func() time.Time { - return fixedTime - } - var wg sync.WaitGroup - requestChecker := false - - mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn+43200))), mock.WithCallback(func(req *http.Request) { - }), - ) - for i := 0; i < 10000; i++ { - wg.Add(1) - go func() { - defer wg.Done() - ar, err := client.AcquireToken(context.Background(), resource) - if err != nil { - t.Error(err) - return - } - if ar.AccessToken == secondToken && ar.Metadata.TokenSource == base.IdentityProvider { - if requestChecker { - t.Error("Error can only call this only once") - } else { - requestChecker = true - } - } - }() - } - // Waiting for all goroutines to finish + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) + getCurrentTime = func() time.Time { + return fixedTime + } + var wg sync.WaitGroup + requestChecker := false + + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn+43200))), mock.WithCallback(func(req *http.Request) { + }), + ) + for i := 0; i < 10000; i++ { + wg.Add(1) go func() { - wg.Wait() - if !requestChecker { - t.Error("Error should be called at least once") + defer wg.Done() + ar, err := client.AcquireToken(context.Background(), resource) + if err != nil { + t.Error(err) + return + } + if ar.AccessToken == secondToken && ar.Metadata.TokenSource == base.IdentityProvider { + if requestChecker { + t.Error("Error can only call this only once") + } else { + requestChecker = true + } } }() - wg.Wait() - }) + } + wg.Wait() + if !requestChecker { + t.Error("Error should be called at least once") + } + } func TestShouldRefresh(t *testing.T) { From 45a995f7d512045a71fa4e635cb506296226dadd Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 21 Feb 2025 12:27:21 +0000 Subject: [PATCH 23/29] Added a time setting for refreshOn for MI --- apps/managedidentity/managedidentity.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index f75c9183..e27003df 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -331,6 +331,10 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac } ar, err := base.AuthResultFromStorage(storageTokenResponse) if err == nil { + timeUntilExpiry := time.Until(ar.ExpiresOn) + if timeUntilExpiry > (time.Hour * 2) { + ar.Metadata.RefreshOn = time.Now().Add(timeUntilExpiry / 2) + } if c.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { defer c.canRefresh.Store(false) if tr, er := c.getToken(ctx, resource); er == nil { From eabd5d3ac1b961154d38fd16ca3f47a331563867 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Fri, 21 Feb 2025 14:13:23 +0000 Subject: [PATCH 24/29] Updated the refreshon time when ests gives empry refreshon --- apps/managedidentity/managedidentity.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index e27003df..1589a9cf 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -331,10 +331,6 @@ func (c Client) AcquireToken(ctx context.Context, resource string, options ...Ac } ar, err := base.AuthResultFromStorage(storageTokenResponse) if err == nil { - timeUntilExpiry := time.Until(ar.ExpiresOn) - if timeUntilExpiry > (time.Hour * 2) { - ar.Metadata.RefreshOn = time.Now().Add(timeUntilExpiry / 2) - } if c.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { defer c.canRefresh.Store(false) if tr, er := c.getToken(ctx, resource); er == nil { @@ -468,6 +464,11 @@ func authResultFromToken(authParams authority.AuthParams, token accesstokens.Tok if err != nil { return base.AuthResult{}, err } + // if refreshOn is not set, set it to half of the time until expiry if expiry is more than 2 hours away + timeUntilExpiry := time.Until(token.ExpiresOn) + if token.RefreshOn.T.IsZero() && timeUntilExpiry > (time.Hour*2) { + token.RefreshOn.T = time.Now().Add(timeUntilExpiry / 2) + } ar, err := base.NewAuthResult(token, account) if err != nil { return base.AuthResult{}, err From 8e6d3ef1d8b618ae891cac2d6dd1b41626bcf1ea Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Tue, 25 Feb 2025 10:48:02 +0000 Subject: [PATCH 25/29] Updated test to fail on first error --- apps/confidential/confidential_test.go | 22 +++++++++++++------- apps/internal/mock/mock.go | 2 +- apps/managedidentity/managedidentity_test.go | 12 +++++++++-- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 6298dc92..9c7910e7 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -814,8 +814,6 @@ func TestNewCredFromTokenProvider(t *testing.T) { t.Fatalf(`unexpected token "%s"`, ar.AccessToken) } } - -//  go test -race -timeout 30s -run ^TestRefreshInMultipleRequests$ github.com/AzureAD/microsoft-authentication-library-for-go/apps/confideintial func TestRefreshInMultipleRequests(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { @@ -869,6 +867,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { } var wg sync.WaitGroup done := make(chan struct{}) + ch := make(chan error, 1) firstTenantChecker := false secondTenantChecker := false @@ -884,7 +883,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { defer wg.Done() ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) if err != nil { - t.Error(err) + select { + case ch <- err: + default: + } return } if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { @@ -899,7 +901,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { defer wg.Done() ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) if err != nil { - t.Error(err) + select { + case ch <- err: + default: + } return } if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { @@ -915,15 +920,18 @@ func TestRefreshInMultipleRequests(t *testing.T) { go func() { wg.Wait() close(done) + close(ch) }() - + select { + case err := <-ch: + t.Fatal(err) + default: + } // Wait for all goroutines to complete <-done - if !secondTenantChecker && !firstTenantChecker { t.Error("Error should be called at least once") } - } func TestRefreshIn(t *testing.T) { diff --git a/apps/internal/mock/mock.go b/apps/internal/mock/mock.go index a23f2a27..defe3a0e 100644 --- a/apps/internal/mock/mock.go +++ b/apps/internal/mock/mock.go @@ -105,7 +105,7 @@ func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, e // Conditionally add the "refresh_in" field if refreshIn is provided if refreshIn > 0 { - body += fmt.Sprintf(`, "refresh_in": %d`, refreshIn) + body += fmt.Sprintf(`, "refresh_in":"%d"`, refreshIn) } // Add the optional fields if they are provided diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 07957e50..a236d827 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1174,7 +1174,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { } var wg sync.WaitGroup requestChecker := false - + ch := make(chan error, 1) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn+43200))), mock.WithCallback(func(req *http.Request) { }), @@ -1185,7 +1185,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { defer wg.Done() ar, err := client.AcquireToken(context.Background(), resource) if err != nil { - t.Error(err) + select { + case ch <- err: + default: + } return } if ar.AccessToken == secondToken && ar.Metadata.TokenSource == base.IdentityProvider { @@ -1198,6 +1201,11 @@ func TestRefreshInMultipleRequests(t *testing.T) { }() } wg.Wait() + select { + case err := <-ch: + t.Fatal(err) + default: + } if !requestChecker { t.Error("Error should be called at least once") } From 7a8eefebe2f00608e3edea96588380f5f5719dad Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Wed, 26 Feb 2025 11:12:32 +0000 Subject: [PATCH 26/29] Refactored the channel for test --- apps/confidential/confidential_test.go | 12 +++--------- apps/managedidentity/managedidentity_test.go | 2 +- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 9c7910e7..cd1b3020 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -866,7 +866,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { return fixedTime } var wg sync.WaitGroup - done := make(chan struct{}) + // done := make(chan struct{}) ch := make(chan error, 1) firstTenantChecker := false @@ -916,22 +916,16 @@ func TestRefreshInMultipleRequests(t *testing.T) { } }() } - // Wait for all goroutines in a separate goroutine - go func() { - wg.Wait() - close(done) - close(ch) - }() + wg.Wait() select { case err := <-ch: t.Fatal(err) default: } - // Wait for all goroutines to complete - <-done if !secondTenantChecker && !firstTenantChecker { t.Error("Error should be called at least once") } + close(ch) } func TestRefreshIn(t *testing.T) { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index a236d827..3f3c6cad 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1209,7 +1209,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { if !requestChecker { t.Error("Error should be called at least once") } - + close(ch) } func TestShouldRefresh(t *testing.T) { From 5b6c74fc96202a5ee935468f53896bb34c32f0ae Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 27 Feb 2025 10:21:38 +0000 Subject: [PATCH 27/29] Resolve PR comments --- apps/confidential/confidential_test.go | 15 ++++---- apps/internal/base/base.go | 21 +++++------- apps/internal/base/base_test.go | 36 -------------------- apps/managedidentity/managedidentity.go | 17 ++++----- apps/managedidentity/managedidentity_test.go | 6 ++-- 5 files changed, 23 insertions(+), 72 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index cd1b3020..cb6cc687 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -825,9 +825,9 @@ func TestRefreshInMultipleRequests(t *testing.T) { refreshIn := 43200 expiresIn := 86400 - originalTime := base.GetCurrentTime + originalTime := base.Now defer func() { - base.GetCurrentTime = originalTime + base.Now = originalTime }() // Create a mock client and append mock responses mockClient := mock.Client{} @@ -862,13 +862,11 @@ func TestRefreshInMultipleRequests(t *testing.T) { t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken) } fixedTime := time.Now().Add(time.Duration(43400) * time.Second) - base.GetCurrentTime = func() time.Time { + base.Now = func() time.Time { return fixedTime } var wg sync.WaitGroup - // done := make(chan struct{}) ch := make(chan error, 1) - firstTenantChecker := false secondTenantChecker := false mockClient.AppendResponse( @@ -925,7 +923,6 @@ func TestRefreshInMultipleRequests(t *testing.T) { if !secondTenantChecker && !firstTenantChecker { t.Error("Error should be called at least once") } - close(ch) } func TestRefreshIn(t *testing.T) { @@ -952,9 +949,9 @@ func TestRefreshIn(t *testing.T) { } { name := "token doesn't need refresh" t.Run(name, func(t *testing.T) { - originalTime := base.GetCurrentTime + originalTime := base.Now defer func() { - base.GetCurrentTime = originalTime + base.Now = originalTime }() // Create a mock client and append mock responses mockClient := mock.Client{} @@ -994,7 +991,7 @@ func TestRefreshIn(t *testing.T) { t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, ar.Metadata.RefreshOn.Second()) } fixedTime := time.Now().Add(time.Duration(tt.secondRequestAfter) * time.Second) - base.GetCurrentTime = func() time.Time { + base.Now = func() time.Time { return fixedTime } ar, err = client.AcquireTokenSilent(context.Background(), tokenScope) diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 9ae76870..7c1da5ba 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -169,7 +169,7 @@ type Client struct { cacheAccessor cache.ExportReplace cacheAccessorMu *sync.RWMutex canRefresh map[string]*atomic.Value - refreshMu *sync.Mutex + canRefreshMu *sync.Mutex } // Option is an optional argument to the New constructor. @@ -247,7 +247,7 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O manager: storage.New(token), pmanager: storage.NewPartitionedManager(token), canRefresh: make(map[string]*atomic.Value), - refreshMu: &sync.Mutex{}, + canRefreshMu: &sync.Mutex{}, } for _, o := range options { if err = o(&client); err != nil { @@ -352,15 +352,15 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen if silent.Claims == "" { ar, err = AuthResultFromStorage(storageTokenResponse) if err == nil { - if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) { - b.refreshMu.Lock() - refreshValue, exists := b.canRefresh[tenant] - if !exists { + if !storageTokenResponse.AccessToken.RefreshOn.T.IsZero() && Now().After(storageTokenResponse.AccessToken.RefreshOn.T) { + b.canRefreshMu.Lock() + refreshValue, ok := b.canRefresh[tenant] + if !ok { refreshValue = &atomic.Value{} refreshValue.Store(false) b.canRefresh[tenant] = refreshValue } - b.refreshMu.Unlock() + b.canRefreshMu.Unlock() if refreshValue.CompareAndSwap(false, true) { defer refreshValue.Store(false) if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { @@ -480,12 +480,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au // This function wraps time.Now() and is used for refreshing the application // was created to test the function against refreshin -var GetCurrentTime = time.Now - -// shouldRefresh returns true if the token should be refreshed. -func shouldRefresh(t time.Time) bool { - return !t.IsZero() && t.Before(GetCurrentTime()) -} +var Now = time.Now func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) { if b.cacheAccessor != nil { diff --git a/apps/internal/base/base_test.go b/apps/internal/base/base_test.go index e07ea110..3642f5e2 100644 --- a/apps/internal/base/base_test.go +++ b/apps/internal/base/base_test.go @@ -443,39 +443,3 @@ func TestAuthResultFromStorage(t *testing.T) { } } } - -// TestShouldRefresh tests the shouldRefresh function -func TestShouldRefresh(t *testing.T) { - // Get the current time to use for comparison - now := time.Now() - tests := []struct { - name string - input time.Time - expected bool - }{ - { - name: "Zero time", - input: time.Time{}, // Zero time - expected: false, // Should return false because it's zero time - }, - { - name: "Future time", - input: now.Add(time.Hour), // 1 hour in the future - expected: false, // Should return false because it's in the future - }, - { - name: "Past time", - input: now.Add(-time.Hour), // 1 hour in the past - expected: true, // Should return true because it's in the past - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := shouldRefresh(tt.input) - if result != tt.expected { - t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected) - } - }) - } -} diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 1589a9cf..956c0913 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -298,17 +298,11 @@ func GetSource() (Source, error) { // This function wraps time.Now() and is used for refreshing the application // was created to test the function against refreshin -var getCurrentTime = time.Now +var Now = time.Now // shouldRefresh returns true if the token should be refreshed. func (b Client) shouldRefresh(t time.Time) bool { - if t.IsZero() || t.After(getCurrentTime()) { - return false - } - if !b.canRefresh.CompareAndSwap(false, true) { - return false - } - return true + return !t.IsZero() && !t.After(Now()) && b.canRefresh.CompareAndSwap(false, true) } // Acquires tokens from the configured managed identity on an azure resource. @@ -465,9 +459,10 @@ func authResultFromToken(authParams authority.AuthParams, token accesstokens.Tok return base.AuthResult{}, err } // if refreshOn is not set, set it to half of the time until expiry if expiry is more than 2 hours away - timeUntilExpiry := time.Until(token.ExpiresOn) - if token.RefreshOn.T.IsZero() && timeUntilExpiry > (time.Hour*2) { - token.RefreshOn.T = time.Now().Add(timeUntilExpiry / 2) + if token.RefreshOn.T.IsZero() { + if lifetime := time.Until(token.ExpiresOn); lifetime > 2*time.Hour { + token.RefreshOn.T = time.Now().Add(lifetime / 2) + } } ar, err := base.NewAuthResult(token, account) if err != nil { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 3f3c6cad..35b462af 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1142,9 +1142,9 @@ func TestRefreshInMultipleRequests(t *testing.T) { miType := SystemAssigned() setEnvVars(t, CloudShell) - originalTime := getCurrentTime + originalTime := Now defer func() { - getCurrentTime = originalTime + Now = originalTime }() before := cacheManager defer func() { cacheManager = before }() @@ -1169,7 +1169,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { } fixedTime := time.Now().Add(time.Duration(43400) * time.Second) - getCurrentTime = func() time.Time { + Now = func() time.Time { return fixedTime } var wg sync.WaitGroup From 9338f41fe681958b3fb5298980541821cd4c0e09 Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 27 Feb 2025 17:22:51 +0000 Subject: [PATCH 28/29] updated code based on comments --- apps/confidential/confidential_test.go | 16 ++++++---------- apps/managedidentity/managedidentity_test.go | 6 +----- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index cb6cc687..d4eb6f40 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -871,6 +871,10 @@ func TestRefreshInMultipleRequests(t *testing.T) { secondTenantChecker := false mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), + mock.WithCallback(func(r *http.Request) { + wg.Done() + time.Sleep(500 * time.Millisecond) + }), ) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200)))) @@ -888,11 +892,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { return } if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if firstTenantChecker { - t.Error("Error can only call this once") - } else { - firstTenantChecker = true - } + firstTenantChecker = true } }() go func() { @@ -906,11 +906,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { return } if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider { - if secondTenantChecker { - t.Error("Error can only call this once") - } else { - secondTenantChecker = true - } + secondTenantChecker = true } }() } diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 35b462af..7298bab2 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -1192,11 +1192,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { return } if ar.AccessToken == secondToken && ar.Metadata.TokenSource == base.IdentityProvider { - if requestChecker { - t.Error("Error can only call this only once") - } else { - requestChecker = true - } + requestChecker = true } }() } From 3705439c4c6acc6ec539165cde546f330417c74d Mon Sep 17 00:00:00 2001 From: Nilesh Choudhary Date: Thu, 27 Feb 2025 19:11:12 +0000 Subject: [PATCH 29/29] Added a test to check the concurrent 2 tenant request --- apps/confidential/confidential_test.go | 97 ++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 6 deletions(-) diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index d4eb6f40..d08b0adb 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -870,12 +870,7 @@ func TestRefreshInMultipleRequests(t *testing.T) { firstTenantChecker := false secondTenantChecker := false mockClient.AppendResponse( - mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), - mock.WithCallback(func(r *http.Request) { - wg.Done() - time.Sleep(500 * time.Millisecond) - }), - ) + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200)))) mockClient.AppendResponse( mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200)))) @@ -921,6 +916,96 @@ func TestRefreshInMultipleRequests(t *testing.T) { } } +func TestConcurrentRequests(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + firstToken := "first token" + secondToken := "new token " + lmo := "login.microsoftonline.com" + refreshIn := 43200 + expiresIn := 86400 + + originalTime := base.Now + defer func() { + base.Now = originalTime + }() + // Create a mock client and append mock responses + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant"))) + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))), + ) + // Create the client instance + client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + // Acquire the first token for first tenant + ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Fatal(err) + } + // Assert the first token is returned + if ar.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken) + } + // Acquire the first token for second tenant + arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Fatal(err) + } + if arSecond.AccessToken != firstToken { + t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken) + } + fixedTime := time.Now().Add(time.Duration(43400) * time.Second) + base.Now = func() time.Time { + return fixedTime + } + var wg sync.WaitGroup + var wgComeplete sync.WaitGroup + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200))), + ) + + mockClient.AppendResponse( + mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))), + mock.WithCallback(func(r *http.Request) { + time.Sleep(2 * time.Second) + }), + ) + + wg.Add(1) + wgComeplete.Add(1) + go func() { + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant")) + if err != nil { + t.Error("Unexpected error", err) + } + wg.Wait() + wgComeplete.Done() + if ar.AccessToken != secondToken+"firstTenant" { + t.Error("wanted first token, got second") + } + }() + go func() { + ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant")) + if err != nil { + t.Error("Unexpected error", err) + } + if ar.AccessToken != secondToken+"secondTenant" { + t.Error("wanted second token, got first") + } + wg.Done() + }() + wgComeplete.Wait() +} + func TestRefreshIn(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil {