Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for refresh_in for ConfidentialClient #542

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2021ef5
Added support for force refresh_in
4gust Jan 9, 2025
8d7a65b
fixed failing test
4gust Jan 9, 2025
566cdd2
updating storage
4gust Jan 13, 2025
78eeace
Updated the force refresh in
4gust Jan 14, 2025
de241bc
Updated test description
4gust Jan 15, 2025
4a38a84
Updated test
4gust Jan 15, 2025
0102a9e
Updated force refresh in
4gust Jan 16, 2025
10e7c6f
Updated some tests
4gust Jan 23, 2025
e95ece3
Updated time.
4gust Jan 23, 2025
91dd29f
Cleaned up code reference with PR comments
4gust Jan 29, 2025
bd448e8
Refactor code
4gust Feb 4, 2025
90e946a
Updated the refreshin system on per tenant base
4gust Feb 7, 2025
0425d56
Added test for force refresh once for each tenant
4gust Feb 11, 2025
d82d813
Update confidential_test.go
4gust Feb 11, 2025
d817f7d
Update confidential_test.go
4gust Feb 11, 2025
b955ede
Update confidential_test.go
4gust Feb 11, 2025
c5393f5
Refactor code
4gust Feb 11, 2025
2700b64
Update confidential_test.go
4gust Feb 13, 2025
87e17fa
Merge branch 'main' of https://github.com/AzureAD/microsoft-authentic…
4gust Feb 17, 2025
e35eb9d
Updated some tests to adapt to change in time
4gust Feb 17, 2025
1cea155
Added RefreshIn logic for Managed Identity
4gust Feb 19, 2025
e6a3b29
Added a sync http client and updated tests
4gust Feb 20, 2025
76218c5
Updated the code
4gust Feb 21, 2025
45a995f
Added a time setting for refreshOn for MI
4gust Feb 21, 2025
eabd5d3
Updated the refreshon time when ests gives empry refreshon
4gust Feb 21, 2025
8e6d3ef
Updated test to fail on first error
4gust Feb 25, 2025
7a8eefe
Refactored the channel for test
4gust Feb 26, 2025
5b6c74f
Resolve PR comments
4gust Feb 27, 2025
9338f41
updated code based on comments
4gust Feb 27, 2025
3705439
Added a test to check the concurrent 2 tenant request
4gust Feb 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ func (cca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []s
if err != nil {
return AuthResult{}, err
}
return cca.base.AuthResultFromToken(ctx, authParams, token, true)
return cca.base.AuthResultFromToken(ctx, authParams, token)
}

// acquireTokenByAuthCodeOptions contains the optional parameters used to acquire an access token using the authorization code flow.
Expand Down Expand Up @@ -731,7 +731,7 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string,
if err != nil {
return AuthResult{}, err
}
return cca.base.AuthResultFromToken(ctx, authParams, token, true)
return cca.base.AuthResultFromToken(ctx, authParams, token)
}

// acquireTokenOnBehalfOfOptions contains optional configuration for AcquireTokenOnBehalfOf
Expand Down
155 changes: 139 additions & 16 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
Expand Down Expand Up @@ -138,8 +139,9 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
RefreshOn: internalTime.DurationTime{T: time.Now().Add(6 * time.Hour)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred, fakeAuthority)
Expand Down Expand Up @@ -255,7 +257,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 86400, 43200)))

client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
Expand All @@ -278,7 +280,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {
}
// new assertion should trigger new token request
token2 := token + "2"
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 86400, 43200)))
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -432,7 +434,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) {
t.Fatal("silent auth should fail because the cache is empty")
}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600, 0)))
if _, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -708,7 +710,8 @@ func TestNewCredFromCertError(t *testing.T) {
func TestNewCredFromTokenProvider(t *testing.T) {
expectedToken := "expected token"
called := false
expiresIn := 4200
expiresIn := 18000
refreshOn := expiresIn / 2
key := struct{}{}
ctx := context.WithValue(context.Background(), key, true)
cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
Expand All @@ -728,6 +731,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
return exported.TokenProviderResult{
AccessToken: expectedToken,
ExpiresInSeconds: expiresIn,
RefreshInSeconds: refreshOn,
}, nil
})
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
Expand All @@ -741,9 +745,13 @@ func TestNewCredFromTokenProvider(t *testing.T) {
if !called {
t.Fatal("token provider wasn't invoked")
}
if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn {
t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v)
if !isTimeSame(ar.ExpiresOn, expiresIn) {
t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, ar.ExpiresOn.Second())
}
if !isTimeSame(ar.Metadata.RefreshOn, refreshOn) {
t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshOn, ar.Metadata.RefreshOn.Second())
}

if ar.AccessToken != expectedToken {
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
}
Expand All @@ -756,6 +764,94 @@ func TestNewCredFromTokenProvider(t *testing.T) {
}
}

func TestRefreshIn(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
firstToken := "first token"
secondToken := "new token"
lmo := "login.microsoftonline.com"
tenant := "tenant"
refreshIn := 43200
expiresIn := 86400
for _, tt := range []struct {
shouldGetNewToken bool
secondRequestAfter int
shouldReturnError bool
}{
{secondRequestAfter: 40000, shouldGetNewToken: false}, // from cache
{secondRequestAfter: 43400, shouldGetNewToken: true}, // refresh in expired so new token
{secondRequestAfter: 40000, shouldGetNewToken: false, shouldReturnError: true}, // refresh in not expired but refresh failed so new token
{secondRequestAfter: 80000, shouldGetNewToken: true, shouldReturnError: false}, // refresh in expired but refresh failed so new token
{secondRequestAfter: 1003400, shouldGetNewToken: true},
} {
name := "token doesn't need refresh"
t.Run(name, func(t *testing.T) {
originalTime := base.GetCurrentTime
defer func() {
base.GetCurrentTime = originalTime
}()
// Create a mock client and append mock responses
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
if tt.shouldReturnError {
mockClient.AppendResponse(
mock.WithHTTPStatusCode(http.StatusBadGateway),
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
} else {
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken, expiresIn, refreshIn))),
)
}

// Create the client instance
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
t.Fatal(err)
}
// Acquire the first token
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
// Assert the first token is returned
if ar.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken)
}
if ar.Metadata.RefreshOn.IsZero() {
t.Fatal("RefreshOn shouldn't be zero")
}
if !isTimeSame(ar.Metadata.RefreshOn, refreshIn) {
t.Fatalf("expected RefreshOn ~= %d seconds from now, got %d", refreshIn, ar.Metadata.RefreshOn.Second())
}
fixedTime := time.Now().Add(time.Duration(tt.secondRequestAfter) * time.Second)
base.GetCurrentTime = func() time.Time {
return fixedTime
}
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
if ar.Metadata.TokenSource != base.Cache && !tt.shouldGetNewToken {
t.Fatal("should have returned from cache.")
}
if (ar.AccessToken == secondToken) != tt.shouldGetNewToken {
t.Fatalf("wanted %q, got %q", secondToken, ar.AccessToken)
}
})
}
}

func isTimeSame(t time.Time, expectedSeconds int) bool {
v := int(time.Until(t).Seconds())
return !(v < expectedSeconds-2 || v > expectedSeconds+2)
}

func TestNewCredFromTokenProviderError(t *testing.T) {
expectedError := "something went wrong"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
Expand All @@ -771,6 +867,33 @@ func TestNewCredFromTokenProviderError(t *testing.T) {
}
}

func TestTokenProviderResultForRefreshIn(t *testing.T) {
accessToken, claims, tenant := "at", "claims", "tenant"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) {
if tpp.Claims != claims {
t.Fatalf(`unexpected claims "%s"`, tpp.Claims)
}
if tpp.TenantID != tenant {
t.Fatalf(`unexpected tenant "%s"`, tpp.TenantID)
}
return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 36000, RefreshInSeconds: 18000}, nil
})
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithClaims(claims), WithTenantID(tenant))
if err != nil {
t.Fatal(err)
}
if !isTimeSame(ar.Metadata.RefreshOn, 18000) {
t.Fatal("RefreshOn should be 18000 seconds from now")
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
}

func TestTokenProviderOptions(t *testing.T) {
accessToken, claims, tenant := "at", "claims", "tenant"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) {
Expand Down Expand Up @@ -820,7 +943,7 @@ func TestWithCache(t *testing.T) {
authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB)
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600, 0)))

cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
Expand Down Expand Up @@ -931,7 +1054,7 @@ func TestWithClaims(t *testing.T) {
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)),
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0)),
mock.WithCallback(func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -995,7 +1118,7 @@ func TestWithClaims(t *testing.T) {
// client has cached access and refresh tokens. When given claims, it should redeem a refresh token for a new access token.
newToken := "new-access-token"
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)),
mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600, 0)),
mock.WithCallback(func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1055,7 +1178,7 @@ func TestWithTenantID(t *testing.T) {
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)),
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
)
client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient))
Expand Down Expand Up @@ -1112,7 +1235,7 @@ func TestWithTenantID(t *testing.T) {
otherTenant := "not-" + test.tenant
if method == "obo" {
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)))
if _, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(otherTenant)); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1150,7 +1273,7 @@ func TestWithTenantID(t *testing.T) {
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)),
mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600, 0)),
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
)
if i == 0 {
Expand Down Expand Up @@ -1206,7 +1329,7 @@ func TestWithInstanceDiscovery(t *testing.T) {
}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(stackurl, tenant)))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)),
)
client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
Expand Down Expand Up @@ -1268,7 +1391,7 @@ func TestWithPortAuthority(t *testing.T) {
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)),
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
)
client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient))
Expand Down
Loading