Skip to content

Commit

Permalink
Add Managed Identity Support (#552)
Browse files Browse the repository at this point in the history
* Added Managed Identity support for multiple sources (IMDS, App Service, CloudShell, AzureML, Service Fabric, Azure Arc)
* Updated tests
* Updated documentation
* Added new Managed Identity client that currently supports cache and retry policies
  • Loading branch information
AndyOHart authored Feb 14, 2025
1 parent c4a7948 commit e6d9244
Show file tree
Hide file tree
Showing 34 changed files with 2,576 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
run: go build ./apps/...

- name: Unit Tests
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/...
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... ./apps/managedidentity/...
# Intergration tests runs on ADO
# - name: Integration Tests
# run: go test -race ./apps/tests/integration/...
Expand Down
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
}
confidentialClient, err := confidential.New("https://login.microsoftonline.com/your_tenant", "client_id", cred)
```
* Initializing a Managed Identity client for SystemAssigned:

```go
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
miSystemAssigned, err := mi.New(mi.SystemAssigned())
if err != nil {
// TODO: handle error
}
```
* Initializing a Managed Identity client for UserAssigned:

```go
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
miSystemAssigned, err := mi.New(mi.UserAssignedClientID("YOUR_CLIENT_ID"))
if err != nil {
// TODO: handle error
}
```

1. Call `AcquireTokenSilent()` to look for a cached token. If `AcquireTokenSilent()` returns an error, call another `AcquireToken...` method to authenticate.

Expand Down Expand Up @@ -96,6 +118,16 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
accessToken := result.AccessToken
```

* ManagedIdentity clietn can simply call `AcquireToken()`:
```go
resource := "<Your resource>"
result, err := miSystemAssigned.AcquireToken(context.TODO(), resource)
if err != nil {
// TODO: handle error
}
accessToken := result.AccessToken
```

## Community Help and Support

We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!) Also browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions.
Expand Down
57 changes: 46 additions & 11 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
Expand All @@ -35,6 +35,7 @@ import (

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
type errorClient struct{}
type contextKey struct{}

func (*errorClient) Do(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String())
Expand Down Expand Up @@ -138,7 +139,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
Expand Down Expand Up @@ -305,7 +306,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {

func TestAcquireTokenByAssertionCallback(t *testing.T) {
calls := 0
key := struct{}{}
key := contextKey{}
ctx := context.WithValue(context.Background(), key, true)
getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) {
if v := c.Value(key); v == nil || !v.(bool) {
Expand Down Expand Up @@ -358,7 +359,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
tr := accesstokens.TokenResponse{
AccessToken: token,
RefreshToken: refresh,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
Expand Down Expand Up @@ -427,6 +428,40 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
}
}

func TestInvalidJsonErrFromResponse(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
tenant := "A"
lmo := "login.microsoftonline.com"
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
// cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it.
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
t.Fatal("silent auth should fail because the cache is empty")
}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
body := fmt.Sprintf(
`{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`,
tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(),
)
mockClient.AppendResponse(mock.WithBody([]byte(body)))
_, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
if err == nil {
t.Fatal("should have failed with InvalidJsonErr Response")
}
var ie errors.InvalidJsonErr
if !errors.As(err, &ie) {
t.Fatal("should have revieved a InvalidJsonErr, but got", err)
}
}

func TestAcquireTokenSilentTenants(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
Expand Down Expand Up @@ -478,7 +513,7 @@ func TestADFSTokenCaching(t *testing.T) {
AccessToken: "at1",
RefreshToken: "rt",
TokenType: "bearer",
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
Expand Down Expand Up @@ -608,7 +643,7 @@ func TestNewCredFromCert(t *testing.T) {
t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) {
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
}, cred, fakeAuthority, opts...)
if err != nil {
Expand Down Expand Up @@ -724,7 +759,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
expectedToken := "expected token"
called := false
expiresIn := 4200
key := struct{}{}
key := contextKey{}
ctx := context.WithValue(context.Background(), key, true)
cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
if called {
Expand Down Expand Up @@ -982,7 +1017,7 @@ func TestWithClaims(t *testing.T) {
case "password":
ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims))
default:
t.Fatalf("test bug: no test for " + method)
t.Fatalf("test bug: no test for %s", method)
}
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1092,7 +1127,7 @@ func TestWithTenantID(t *testing.T) {
case "obo":
ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant))
default:
t.Fatalf("test bug: no test for " + method)
t.Fatalf("test bug: no test for %s", method)
}
if err != nil {
if test.expectError {
Expand Down Expand Up @@ -1402,7 +1437,7 @@ func TestWithAuthenticationScheme(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "TokenType",
Expand Down Expand Up @@ -1442,7 +1477,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
Expand Down
9 changes: 9 additions & 0 deletions apps/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,20 @@ type CallErr struct {
Err error
}

type InvalidJsonErr struct {
Err error
}

// Errors implements error.Error().
func (e CallErr) Error() string {
return e.Err.Error()
}

// Errors implements error.Error().
func (e InvalidJsonErr) Error() string {
return e.Err.Error()
}

// Verbose prints a versbose error message with the request or response.
func (e CallErr) Verbose() string {
e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need
Expand Down
5 changes: 2 additions & 3 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
Expand Down Expand Up @@ -111,7 +111,6 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
if err := storageTokenResponse.AccessToken.Validate(); err != nil {
return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err)
}

account := storageTokenResponse.Account
accessToken := storageTokenResponse.AccessToken.Secret
grantedScopes := strings.Split(storageTokenResponse.AccessToken.Scopes, scopeSeparator)
Expand Down Expand Up @@ -146,7 +145,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
Account: account,
IDToken: tokenResponse.IDToken,
AccessToken: tokenResponse.AccessToken,
ExpiresOn: tokenResponse.ExpiresOn.T,
ExpiresOn: tokenResponse.ExpiresOn,
GrantedScopes: tokenResponse.GrantedScopes.Slice,
Metadata: AuthResultMetadata{
TokenSource: IdentityProvider,
Expand Down
12 changes: 6 additions & 6 deletions apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
Expand Down Expand Up @@ -50,7 +50,7 @@ func fakeClient(t *testing.T, opts ...Option) Client {
client.Token.AccessTokens = &fake.AccessTokens{
AccessToken: accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
FamilyID: "family-id",
GrantedScopes: accesstokens.Scopes{Slice: testScopes},
IDToken: fakeIDToken,
Expand Down Expand Up @@ -135,7 +135,7 @@ func TestAcquireTokenSilentScopes(t *testing.T) {
accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ClientInfo: accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
ExpiresOn: time.Now().Add(-time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
Expand Down Expand Up @@ -178,7 +178,7 @@ func TestAcquireTokenSilentGrantedScopes(t *testing.T) {
},
accesstokens.TokenResponse{
AccessToken: expectedToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: grantedScopes},
TokenType: "Bearer",
},
Expand Down Expand Up @@ -335,7 +335,7 @@ func TestCreateAuthenticationResult(t *testing.T) {
desc: "no declined scopes",
input: accesstokens.TokenResponse{
AccessToken: "accessToken",
ExpiresOn: internalTime.DurationTime{T: future},
ExpiresOn: future,
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: nil,
},
Expand All @@ -353,7 +353,7 @@ func TestCreateAuthenticationResult(t *testing.T) {
desc: "declined scopes",
input: accesstokens.TokenResponse{
AccessToken: "accessToken",
ExpiresOn: internalTime.DurationTime{T: future},
ExpiresOn: future,
GrantedScopes: accesstokens.Scopes{Slice: []string{"user.read"}},
DeclinedScopes: []string{"openid"},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex

// Key outputs the key that can be used to uniquely look up this entry in a map.
func (a AccessToken) Key() string {
ks := []string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes}
key := strings.Join(
[]string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes},
ks,
shared.CacheKeySeparator,
)
// add token type to key for new access tokens types. skip for bearer token type to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func TestContractUnmarshalJSON(t *testing.T) {
}
if diff := pretty.Compare(want, got); diff != "" {
t.Errorf("TestContractUnmarshalJSON: -want/+got:\n%s", diff)
t.Errorf(string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage)))
t.Errorf("%s", string(got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage)))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
realm,
clientID,
cachedAt,
tokenResponse.ExpiresOn.T,
tokenResponse.ExpiresOn,
tokenResponse.ExtExpiresOn.T,
target,
tokenResponse.AccessToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"testing"
"time"

internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
Expand Down Expand Up @@ -59,7 +58,7 @@ func TestOBOAccessTokenScopes(t *testing.T) {
accesstokens.TokenResponse{
AccessToken: scope[0] + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: scope},
IDToken: idt,
RefreshToken: upn + "-rt",
Expand Down Expand Up @@ -121,7 +120,7 @@ func TestOBOPartitioning(t *testing.T) {
accesstokens.TokenResponse{
AccessToken: upn + "-at",
ClientInfo: accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: scopes},
IDToken: idt,
RefreshToken: upn + "-rt",
Expand Down
Loading

0 comments on commit e6d9244

Please sign in to comment.