Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for refresh_in for ConfidentialClient #542

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2021ef5
Added support for force refresh_in
4gust Jan 9, 2025
8d7a65b
fixed failing test
4gust Jan 9, 2025
566cdd2
updating storage
4gust Jan 13, 2025
78eeace
Updated the force refresh in
4gust Jan 14, 2025
de241bc
Updated test description
4gust Jan 15, 2025
4a38a84
Updated test
4gust Jan 15, 2025
0102a9e
Updated force refresh in
4gust Jan 16, 2025
10e7c6f
Updated some tests
4gust Jan 23, 2025
e95ece3
Updated time.
4gust Jan 23, 2025
91dd29f
Cleaned up code reference with PR comments
4gust Jan 29, 2025
bd448e8
Refactor code
4gust Feb 4, 2025
90e946a
Updated the refreshin system on per tenant base
4gust Feb 7, 2025
0425d56
Added test for force refresh once for each tenant
4gust Feb 11, 2025
d82d813
Update confidential_test.go
4gust Feb 11, 2025
d817f7d
Update confidential_test.go
4gust Feb 11, 2025
b955ede
Update confidential_test.go
4gust Feb 11, 2025
c5393f5
Refactor code
4gust Feb 11, 2025
2700b64
Update confidential_test.go
4gust Feb 13, 2025
87e17fa
Merge branch 'main' of https://github.com/AzureAD/microsoft-authentic…
4gust Feb 17, 2025
e35eb9d
Updated some tests to adapt to change in time
4gust Feb 17, 2025
1cea155
Added RefreshIn logic for Managed Identity
4gust Feb 19, 2025
e6a3b29
Added a sync http client and updated tests
4gust Feb 20, 2025
76218c5
Updated the code
4gust Feb 21, 2025
45a995f
Added a time setting for refreshOn for MI
4gust Feb 21, 2025
eabd5d3
Updated the refreshon time when ests gives empry refreshon
4gust Feb 21, 2025
8e6d3ef
Updated test to fail on first error
4gust Feb 25, 2025
7a8eefe
Refactored the channel for test
4gust Feb 26, 2025
5b6c74f
Resolve PR comments
4gust Feb 27, 2025
9338f41
updated code based on comments
4gust Feb 27, 2025
3705439
Added a test to check the concurrent 2 tenant request
4gust Feb 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ func (cca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []s
if err != nil {
return AuthResult{}, err
}
return cca.base.AuthResultFromToken(ctx, authParams, token, true)
return cca.base.AuthResultFromToken(ctx, authParams, token)
}

// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow.
Expand Down Expand Up @@ -733,7 +733,7 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string,
if err != nil {
return AuthResult{}, err
}
return cca.base.AuthResultFromToken(ctx, authParams, token, true)
return cca.base.AuthResultFromToken(ctx, authParams, token)
}

// acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf
Expand Down
269 changes: 252 additions & 17 deletions apps/confidential/confidential_test.go

Large diffs are not rendered by default.

46 changes: 38 additions & 8 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ package base

import (
"context"
"errors"
"fmt"
"net/url"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
Expand Down Expand Up @@ -94,6 +95,7 @@ type AuthResult struct {

// AuthResultMetadata which contains meta data for the AuthResult
type AuthResultMetadata struct {
RefreshOn time.Time
TokenSource TokenSource
}

Expand Down Expand Up @@ -132,6 +134,7 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
DeclinedScopes: nil,
Metadata: AuthResultMetadata{
TokenSource: Cache,
RefreshOn: storageTokenResponse.AccessToken.RefreshOn.T,
},
}, nil
}
Expand All @@ -149,6 +152,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
GrantedScopes: tokenResponse.GrantedScopes.Slice,
Metadata: AuthResultMetadata{
TokenSource: IdentityProvider,
RefreshOn: tokenResponse.RefreshOn.T,
},
}, nil
}
Expand All @@ -164,6 +168,8 @@ type Client struct {
AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
cacheAccessor cache.ExportReplace
cacheAccessorMu *sync.RWMutex
canRefresh map[string]*atomic.Value
refreshMu *sync.Mutex
}

// Option is an optional argument to the New constructor.
Expand Down Expand Up @@ -240,6 +246,8 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O
cacheAccessorMu: &sync.RWMutex{},
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
canRefresh: make(map[string]*atomic.Value),
refreshMu: &sync.Mutex{},
}
for _, o := range options {
if err = o(&client); err != nil {
Expand Down Expand Up @@ -344,6 +352,22 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
if silent.Claims == "" {
ar, err = AuthResultFromStorage(storageTokenResponse)
if err == nil {
if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) {
b.refreshMu.Lock()
refreshValue, exists := b.canRefresh[tenant]
if !exists {
refreshValue = &atomic.Value{}
refreshValue.Store(false)
b.canRefresh[tenant] = refreshValue
}
b.refreshMu.Unlock()
if refreshValue.CompareAndSwap(false, true) {
defer refreshValue.Store(false)
if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil {
return b.AuthResultFromToken(ctx, authParams, tr)
}
}
}
ar.AccessToken, err = authParams.AuthnScheme.FormatAccessToken(ar.AccessToken)
return ar, err
}
Expand All @@ -361,7 +385,7 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
if err != nil {
return ar, err
}
return b.AuthResultFromToken(ctx, authParams, token, true)
return b.AuthResultFromToken(ctx, authParams, token)
}

func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) {
Expand Down Expand Up @@ -390,7 +414,7 @@ func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams Acqui
return AuthResult{}, err
}

return b.AuthResultFromToken(ctx, authParams, token, true)
return b.AuthResultFromToken(ctx, authParams, token)
}

// AcquireTokenOnBehalfOf acquires a security token for an app using middle tier apps access token.
Expand Down Expand Up @@ -419,15 +443,12 @@ func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams Acq
authParams.UserAssertion = onBehalfOfParams.UserAssertion
token, err := b.Token.OnBehalfOf(ctx, authParams, onBehalfOfParams.Credential)
if err == nil {
ar, err = b.AuthResultFromToken(ctx, authParams, token, true)
ar, err = b.AuthResultFromToken(ctx, authParams, token)
}
return ar, err
}

func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) {
if !cacheWrite {
return NewAuthResult(token, shared.Account{})
}
func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse) (AuthResult, error) {
var m manager = b.manager
if authParams.AuthorizationType == authority.ATOnBehalfOf {
m = b.pmanager
Expand Down Expand Up @@ -457,6 +478,15 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au
return ar, err
}

// This function wraps time.Now() and is used for refreshing the application
// was created to test the function against refreshin
var GetCurrentTime = time.Now

// shouldRefresh returns true if the token should be refreshed.
func shouldRefresh(t time.Time) bool {
return !t.IsZero() && t.Before(GetCurrentTime())
}

func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) {
if b.cacheAccessor != nil {
b.cacheAccessorMu.RLock()
Expand Down
39 changes: 37 additions & 2 deletions apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func TestCacheIOErrors(t *testing.T) {
if !errors.Is(actual, expected) {
t.Fatalf(`expected "%v", got "%v"`, expected, actual)
}
_, actual = client.AuthResultFromToken(ctx, authority.AuthParams{AuthnScheme: &authority.BearerAuthenticationScheme{}}, accesstokens.TokenResponse{}, true)
_, actual = client.AuthResultFromToken(ctx, authority.AuthParams{AuthnScheme: &authority.BearerAuthenticationScheme{}}, accesstokens.TokenResponse{})
if !errors.Is(actual, expected) {
t.Fatalf(`expected "%v", got "%v"`, expected, actual)
}
Expand Down Expand Up @@ -284,7 +284,6 @@ func TestCacheIOErrors(t *testing.T) {
IDToken: fakeIDToken,
RefreshToken: "rt",
},
true,
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -444,3 +443,39 @@ func TestAuthResultFromStorage(t *testing.T) {
}
}
}

// TestShouldRefresh tests the shouldRefresh function
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we need this comment

func TestShouldRefresh(t *testing.T) {
// Get the current time to use for comparison
now := time.Now()
tests := []struct {
name string
input time.Time
expected bool
}{
{
name: "Zero time",
input: time.Time{}, // Zero time
expected: false, // Should return false because it's zero time
},
{
name: "Future time",
input: now.Add(time.Hour), // 1 hour in the future
expected: false, // Should return false because it's in the future
},
{
name: "Past time",
input: now.Add(-time.Hour), // 1 hour in the past
expected: true, // Should return true because it's in the past
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := shouldRefresh(tt.input)
if result != tt.expected {
t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected)
}
})
}
}
4 changes: 3 additions & 1 deletion apps/internal/base/storage/items.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type AccessToken struct {
ClientID string `json:"client_id,omitempty"`
Secret string `json:"secret,omitempty"`
Scopes string `json:"target,omitempty"`
RefreshOn internalTime.Unix `json:"refresh_on,omitempty"`
ExpiresOn internalTime.Unix `json:"expires_on,omitempty"`
ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"`
CachedAt internalTime.Unix `json:"cached_at,omitempty"`
Expand All @@ -83,7 +84,7 @@ type AccessToken struct {
}

// NewAccessToken is the constructor for AccessToken.
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, refreshOn, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
return AccessToken{
HomeAccountID: homeID,
Environment: env,
Expand All @@ -93,6 +94,7 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex
Secret: token,
Scopes: scopes,
CachedAt: internalTime.Unix{T: cachedAt.UTC()},
RefreshOn: internalTime.Unix{T: refreshOn.UTC()},
ExpiresOn: internalTime.Unix{T: expiresOn.UTC()},
ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()},
TokenType: tokenType,
Expand Down
3 changes: 3 additions & 0 deletions apps/internal/base/storage/items_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ var (
)

func TestCreateAccessToken(t *testing.T) {

testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testRefreshOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC)
actualAt := NewAccessToken("testHID",
"env",
"realm",
"clientID",
testCachedAt,
testRefreshOn,
testExpiresOn,
testExtExpiresOn,
"user.read",
Expand Down
1 change: 1 addition & 0 deletions apps/internal/base/storage/partitioned_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
realm,
clientID,
cachedAt,
tokenResponse.RefreshOn.T,
tokenResponse.ExpiresOn,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
2 changes: 2 additions & 0 deletions apps/internal/base/storage/partitioned_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func TestReadPartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid user.read",
"secret",
"Bearer",
Expand Down Expand Up @@ -210,6 +211,7 @@ func TestWritePartitionedAccessToken(t *testing.T) {
now,
now,
now,
now,
"openid",
"secret",
"tokenType",
Expand Down
1 change: 1 addition & 0 deletions apps/internal/base/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
realm,
clientID,
cachedAt,
tokenResponse.RefreshOn.T,
tokenResponse.ExpiresOn,
tokenResponse.ExtExpiresOn.T,
target,
Expand Down
Loading
Loading