diff --git a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs index cb441baa80..0861401a62 100644 --- a/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs +++ b/src/client/Microsoft.Identity.Client/Internal/Requests/ManagedIdentityAuthRequest.cs @@ -152,8 +152,9 @@ private async Task SendTokenRequestForManagedIdentityAsync await ResolveAuthorityAsync().ConfigureAwait(false); - ManagedIdentityClient managedIdentityClient = - new ManagedIdentityClient(AuthenticationRequestParameters.RequestContext); + ManagedIdentityClient managedIdentityClient = + await ManagedIdentityClient.CreateAsync(AuthenticationRequestParameters.RequestContext, cancellationToken) + .ConfigureAwait(false); ManagedIdentityResponse managedIdentityResponse = await managedIdentityClient diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/CredentialManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/CredentialManagedIdentitySource.cs new file mode 100644 index 0000000000..b8c3dde5d5 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/CredentialManagedIdentitySource.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net.Http; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Internal; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class CredentialManagedIdentitySource : AbstractManagedIdentity + { + /// + /// Factory method to create an instance of `CredentialManagedIdentitySource`. + /// + public static AbstractManagedIdentity Create(RequestContext requestContext) + { + requestContext.Logger.Info(() => "[Managed Identity] Using credential based managed identity."); + + return new CredentialManagedIdentitySource(requestContext); + } + + private CredentialManagedIdentitySource(RequestContext requestContext) : + base(requestContext, ManagedIdentitySource.Credential) + { + } + + /// + /// Even though the Credential flow does not use this request, we need to satisfy the abstract contract. + /// Return a minimal, valid ManagedIdentityRequest using the fixed credential endpoint. + /// + /// The resource identifier (ignored in this flow). + /// A ManagedIdentityRequest instance using the credential endpoint. + protected override ManagedIdentityRequest CreateRequest(string resource) + { + // Return a minimal request with the fixed credential endpoint. + return new ManagedIdentityRequest( + HttpMethod.Post, + new Uri("http://169.254.169.254/metadata/identity/credential?cred-api-version=1.0")); + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsCredentialProbeManager.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsCredentialProbeManager.cs new file mode 100644 index 0000000000..0d6febaad4 --- /dev/null +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ImdsCredentialProbeManager.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Http; + +namespace Microsoft.Identity.Client.ManagedIdentity +{ + internal class ImdsCredentialProbeManager + { + private const string CredentialEndpoint = "http://169.254.169.254/metadata/identity/credential"; + private const string ProbeBody = "."; + private const string ImdsHeader = "IMDS/"; + private readonly IHttpManager _httpManager; + private readonly ILoggerAdapter _logger; + + public ImdsCredentialProbeManager(IHttpManager httpManager, ILoggerAdapter logger) + { + _httpManager = httpManager ?? throw new ArgumentNullException(nameof(httpManager)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + public async Task ExecuteAsync(CancellationToken cancellationToken = default) + { + _logger.Info("[Credential Probe] Initiating probe to IMDS credential endpoint."); + + var request = new ManagedIdentityRequest(HttpMethod.Post, new Uri($"{CredentialEndpoint}?cred-api-version=1.0")) + { + Content = ProbeBody + }; + + HttpContent httpContent = request.CreateHttpContent(); + + _logger.Info($"[Credential Probe] Sending request to {CredentialEndpoint}"); + _logger.Verbose(() => $"[Credential Probe] Request Headers: {string.Join(", ", request.Headers)}"); + _logger.Verbose(() => $"[Credential Probe] Request Body: {ProbeBody}"); + + try + { + HttpResponse response = await _httpManager.SendRequestAsync( + request.ComputeUri(), + request.Headers, + httpContent, + request.Method, + _logger, + doNotThrow: true, + mtlsCertificate: null, + customHttpClient: null, + cancellationToken).ConfigureAwait(false); + + LogResponseDetails(response); + + return EvaluateProbeResponse(response); + } + catch (Exception ex) + { + _logger.Error($"[Credential Probe] Exception during probe: {ex.Message}"); + _logger.Error($"[Credential Probe] Stack Trace: {ex.StackTrace}"); + return false; + } + } + + private void LogResponseDetails(HttpResponse response) + { + if (response == null) + { + _logger.Error("[Credential Probe] No response received from the server."); + return; + } + + _logger.Info($"[Credential Probe] Response Status Code: {response.StatusCode}"); + _logger.Verbose(() => $"[Credential Probe] Response Headers: {string.Join(", ", response.HeadersAsDictionary)}"); + + if (response.Body != null) + { + _logger.Verbose(() => $"[Credential Probe] Response Body: {response.Body}"); + } + } + + private bool EvaluateProbeResponse(HttpResponse response) + { + if (response == null) + { + _logger.Error("[Credential Probe] No response received from the server."); + return false; + } + + _logger.Info($"[Credential Probe] Evaluating response from credential endpoint. Status Code: {response.StatusCode}"); + + if (response.HeadersAsDictionary.TryGetValue("Server", out string serverHeader) && + serverHeader.StartsWith(ImdsHeader, StringComparison.OrdinalIgnoreCase)) + { + _logger.Info($"[Credential Probe] Credential endpoint supported. Server Header: {serverHeader}"); + return true; + } + + _logger.Warning($"[Credential Probe] Credential endpoint not supported. Status Code: {response.StatusCode}"); + return false; + } + } +} diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs index 80a45bb0da..f97775f711 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityClient.cs @@ -9,6 +9,7 @@ using Microsoft.Identity.Client.PlatformsCommon.Shared; using System.IO; using Microsoft.Identity.Client.Core; +using Microsoft.Identity.Client.Http; namespace Microsoft.Identity.Client.ManagedIdentity { @@ -19,14 +20,39 @@ internal class ManagedIdentityClient { private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe"; private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds"; + + // Cache for the managed identity source + private static ManagedIdentitySource? s_cachedManagedIdentitySource; + private static readonly SemaphoreSlim s_credentialSemaphore = new(1, 1); private readonly AbstractManagedIdentity _identitySource; - public ManagedIdentityClient(RequestContext requestContext) + internal static async Task CreateAsync(RequestContext requestContext, CancellationToken cancellationToken = default) { - using (requestContext.Logger.LogMethodDuration()) + if (requestContext == null) { - _identitySource = SelectManagedIdentitySource(requestContext); + throw new ArgumentNullException(nameof(requestContext), "RequestContext cannot be null."); } + + requestContext.Logger?.Info("[ManagedIdentityClient] Creating ManagedIdentityClient."); + + AbstractManagedIdentity identitySource = await SelectManagedIdentitySourceAsync(requestContext, cancellationToken).ConfigureAwait(false); + + requestContext.Logger?.Info($"[ManagedIdentityClient] Managed identity source selected: {identitySource.GetType().Name}."); + + return new ManagedIdentityClient(identitySource); + } + + private ManagedIdentityClient(AbstractManagedIdentity identitySource) + { + _identitySource = identitySource ?? throw new ArgumentNullException(nameof(identitySource), "Identity source cannot be null."); + } + + /// + /// Resets the cached managed identity source. Used only for testing purposes. + /// + internal static void ResetManagedIdentitySourceCache() + { + s_cachedManagedIdentitySource = null; } internal Task SendTokenRequestForManagedIdentityAsync(AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken) @@ -34,23 +60,34 @@ internal Task SendTokenRequestForManagedIdentityAsync(A return _identitySource.AuthenticateAsync(parameters, cancellationToken); } - // This method tries to create managed identity source for different sources, if none is created then defaults to IMDS. - private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContext requestContext) + /// + /// This method tries to create managed identity source for different sources. + /// If none is created then defaults to IMDS. + /// + /// + /// + /// + private static async Task SelectManagedIdentitySourceAsync(RequestContext requestContext, CancellationToken cancellationToken = default) { - return GetManagedIdentitySource(requestContext.Logger) switch + ManagedIdentitySource source = await GetManagedIdentitySourceAsync(requestContext.ServiceBundle, cancellationToken).ConfigureAwait(false); + + return source switch { ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext), ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), + ManagedIdentitySource.Credential => CredentialManagedIdentitySource.Create(requestContext), _ => new ImdsManagedIdentitySource(requestContext) }; } - // Detect managed identity source based on the availability of environment variables. - // The result of this method is not cached because reading environment variables is cheap. - // This method is perf sensitive any changes should be benchmarked. + /// + /// Compute the managed identity source based on the environment variables. + /// + /// + /// internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter logger = null) { string identityEndpoint = EnvironmentVariables.IdentityEndpoint; @@ -97,6 +134,128 @@ internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter lo } } + /// + /// Compute the managed identity source based on the environment variables and the probe. + /// + /// + /// + /// + /// + public static async Task GetManagedIdentitySourceAsync( + IServiceBundle serviceBundle, + CancellationToken cancellationToken = default) + { + if (serviceBundle == null) + { + throw new ArgumentNullException(nameof(serviceBundle), "ServiceBundle is required to initialize the probe manager."); + } + + ILoggerAdapter logger = serviceBundle.ApplicationLogger; + + logger.Verbose(() => s_cachedManagedIdentitySource.HasValue + ? "[Managed Identity] Using cached managed identity source." + : "[Managed Identity] Computing managed identity source asynchronously."); + + if (s_cachedManagedIdentitySource.HasValue) + { + return s_cachedManagedIdentitySource.Value; + } + + // Use SemaphoreSlim to prevent multiple threads from computing at the same time + logger.Verbose(() => "[Managed Identity] Entering managed identity source semaphore."); + await s_credentialSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + logger.Verbose(() => "[Managed Identity] Entered managed identity source semaphore."); + + try + { + // Ensure another thread didn't set this while waiting on semaphore + if (s_cachedManagedIdentitySource.HasValue) + { + return s_cachedManagedIdentitySource.Value; + } + + // Initialize probe manager + var probeManager = new ImdsCredentialProbeManager( + serviceBundle.HttpManager, + serviceBundle.ApplicationLogger); + + // Compute the managed identity source + s_cachedManagedIdentitySource = await ComputeManagedIdentitySourceAsync( + probeManager, + serviceBundle.ApplicationLogger, + cancellationToken).ConfigureAwait(false); + + logger.Info($"[Managed Identity] Managed identity source determined: {s_cachedManagedIdentitySource.Value}."); + + return s_cachedManagedIdentitySource.Value; + } + finally + { + s_credentialSemaphore.Release(); + logger.Verbose(() => "[Managed Identity] Released managed identity source semaphore."); + } + } + + /// + /// Compute the managed identity source based on the environment variables and the probe. + /// + /// + /// + /// + /// + private static async Task ComputeManagedIdentitySourceAsync( + ImdsCredentialProbeManager imdsCredentialProbeManager, + ILoggerAdapter logger, + CancellationToken cancellationToken) + { + string identityEndpoint = EnvironmentVariables.IdentityEndpoint; + string identityHeader = EnvironmentVariables.IdentityHeader; + string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint; + string msiEndpoint = EnvironmentVariables.MsiEndpoint; + string imdsEndpoint = EnvironmentVariables.ImdsEndpoint; + string msiSecretMachineLearning = EnvironmentVariables.MsiSecret; + + if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader)) + { + if (!string.IsNullOrEmpty(identityServerThumbprint)) + { + return ManagedIdentitySource.ServiceFabric; + } + else + { + return ManagedIdentitySource.AppService; + } + } + else if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint)) + { + return ManagedIdentitySource.MachineLearning; + } + else if (!string.IsNullOrEmpty(msiEndpoint)) + { + return ManagedIdentitySource.CloudShell; + } + else if (ValidateAzureArcEnvironment(identityEndpoint, imdsEndpoint, logger)) + { + return ManagedIdentitySource.AzureArc; + } + else + { + logger?.Info("[Managed Identity] Probing for credential endpoint."); + bool isSuccess = await imdsCredentialProbeManager.ExecuteAsync(cancellationToken).ConfigureAwait(false); + + if (isSuccess) + { + logger?.Info("[Managed Identity] Credential endpoint detected."); + return ManagedIdentitySource.Credential; + } + else + { + logger?.Verbose(() => "[Managed Identity] Defaulting to IMDS as credential endpoint not detected."); + return ManagedIdentitySource.DefaultToImds; + } + } + } + // Method to return true if a file exists and is not empty to validate the Azure arc environment. private static bool ValidateAzureArcEnvironment(string identityEndpoint, string imdsEndpoint, ILoggerAdapter logger) { @@ -118,8 +277,8 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string { logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is available through file detection."); return true; - } - + } + logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available."); return false; } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs index 6eb5a5bba0..6a5bb8fb44 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentityRequest.cs @@ -22,6 +22,7 @@ internal class ManagedIdentityRequest public IDictionary BodyParameters { get; } public IDictionary QueryParameters { get; } + public string Content { get; internal set; } public ManagedIdentityRequest(HttpMethod method, Uri endpoint) { @@ -39,5 +40,21 @@ public Uri ComputeUri() return uriBuilder.Uri; } + + public HttpContent CreateHttpContent() + { + if (!string.IsNullOrEmpty(Content)) + { + return new StringContent(Content, Encoding.UTF8, "application/json"); + } + + if (BodyParameters.Count > 0) + { + var formData = new FormUrlEncodedContent(BodyParameters); + return formData; + } + + return null; // No body content + } } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs index 69e3471bdf..6263aa6402 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ManagedIdentitySource.cs @@ -53,6 +53,12 @@ public enum ManagedIdentitySource /// /// The source to acquire token for managed identity is Machine Learning Service. /// - MachineLearning + MachineLearning, + + /// + /// Indicates that the source is credential endpoint based on the probe. + /// This is used to detect the new managed identity credential source. + /// + Credential } } diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs index eded64dc91..f7ef217839 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentityApplication.cs @@ -36,12 +36,16 @@ internal ManagedIdentityApplication( AppTokenCacheInternal = configuration.AppTokenCacheInternalForTest ?? new TokenCache(ServiceBundle, true); - this.ServiceBundle.ApplicationLogger.Verbose(()=>$"ManagedIdentityApplication {configuration.GetHashCode()} created"); + s_serviceBundle = this.ServiceBundle; + + s_serviceBundle.ApplicationLogger.Verbose(()=>$"ManagedIdentityApplication {configuration.GetHashCode()} created"); } // Stores all app tokens internal ITokenCacheInternal AppTokenCacheInternal { get; } + private static IServiceBundle s_serviceBundle; + /// public AcquireTokenForManagedIdentityParameterBuilder AcquireTokenForManagedIdentity(string resource) { @@ -63,5 +67,14 @@ public static ManagedIdentitySource GetManagedIdentitySource() { return ManagedIdentityClient.GetManagedIdentitySource(); } + + /// + /// Detects and returns the managed identity source available on the environment asynchronously. + /// + /// Managed identity source detected on the environment if any. + public static async Task GetManagedIdentitySourceAsync(CancellationToken cancellationToken = default) + { + return await ManagedIdentityClient.GetManagedIdentitySourceAsync(s_serviceBundle, cancellationToken).ConfigureAwait(false); + } } } diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net462/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net472/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-android/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0-ios/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/net8.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt index e69de29bb2..9f2731d05f 100644 --- a/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/client/Microsoft.Identity.Client/PublicApi/netstandard2.0/PublicAPI.Unshipped.txt @@ -0,0 +1,2 @@ +Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource.Credential = 8 -> Microsoft.Identity.Client.ManagedIdentity.ManagedIdentitySource +static Microsoft.Identity.Client.ManagedIdentityApplication.GetManagedIdentitySourceAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task diff --git a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs index 91e5c3d268..5171ac466b 100644 --- a/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs +++ b/tests/Microsoft.Identity.Test.Common/Core/Mocks/MockHelpers.cs @@ -582,5 +582,21 @@ public static MsalTokenResponse CreateMsalRunTimeBrokerTokenResponse(string acce TokenSource = TokenSource.Broker }; } + + public static void AddCredentialEndpointNotFoundHandlers( + ManagedIdentitySource managedIdentitySource, + MockHttpManager httpManager, + int count = 4) + { + if (managedIdentitySource != ManagedIdentitySource.Imds) + { + return; // Only add handlers for IMDS + } + + for (int i = 0; i < count; i++) + { + httpManager.AddMockHandlerContentNotFound(HttpMethod.Post); + } + } } } diff --git a/tests/Microsoft.Identity.Test.Common/TestCommon.cs b/tests/Microsoft.Identity.Test.Common/TestCommon.cs index 95772c3da7..64c1076b9b 100644 --- a/tests/Microsoft.Identity.Test.Common/TestCommon.cs +++ b/tests/Microsoft.Identity.Test.Common/TestCommon.cs @@ -31,6 +31,7 @@ using Microsoft.Identity.Test.Common.Core.Mocks; using NSubstitute; using static Microsoft.Identity.Client.TelemetryCore.Internal.Events.ApiEvent; +using Microsoft.Identity.Client.ManagedIdentity; namespace Microsoft.Identity.Test.Common { @@ -45,6 +46,7 @@ public static void ResetInternalStaticCaches() OidcRetrieverWithCache.ResetCacheForTest(); AuthorityManager.ClearValidationCache(); SingletonThrottlingManager.GetInstance().ResetCache(); + ManagedIdentityClient.ResetManagedIdentitySourceCache(); } public static object GetPropValue(object src, string propName) diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CredentialTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CredentialTests.cs new file mode 100644 index 0000000000..cedc6d18af --- /dev/null +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/CredentialTests.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Diagnostics; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.AppConfig; +using Microsoft.Identity.Client.Internal; +using Microsoft.Identity.Client.ManagedIdentity; +using Microsoft.Identity.Client.TelemetryCore.Internal.Events; +using Microsoft.Identity.Test.Common; +using Microsoft.Identity.Test.Common.Core.Helpers; +using Microsoft.Identity.Test.Common.Core.Mocks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using OpenTelemetry.Resources; +using static Microsoft.Identity.Test.Common.Core.Helpers.ManagedIdentityTestUtil; + +namespace Microsoft.Identity.Test.Unit.ManagedIdentityTests +{ + + [TestClass] + public class CredentialTests : TestBase + { + private const string ImdsEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token"; + internal const string Resource = "https://management.azure.com"; + internal const string CredentialEndpoint = "http://169.254.169.254/metadata/identity/credential"; + internal const string MtlsEndpoint = "https://centraluseuap.mtlsauth.microsoft.com/" + + "72f988bf-86f1-41af-91ab-2d7cd011db47/oauth2/v2.0/token"; + + [TestInitialize] + public override void TestInitialize() + { + TestCommon.ResetInternalStaticCaches(); + } + + [TestMethod] + public async Task CredentialSourceFailedFallbackToImdsTestAsync() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager(isManagedIdentity: true)) + { + SetEnvironmentVariables(ManagedIdentitySource.Imds, ImdsEndpoint); + + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + // Disabling shared cache options to avoid cross test pollution. + miBuilder.Config.AccessorOptions = null; + + var mi = miBuilder.Build(); + + MockHelpers.AddCredentialEndpointNotFoundHandlers(ManagedIdentitySource.Imds, httpManager); + + httpManager.AddManagedIdentityMockHandler( + ImdsEndpoint, + Resource, + MockHelpers.GetMsiSuccessfulResponse(), + ManagedIdentitySource.Imds); + + var result = await mi.AcquireTokenForManagedIdentity(Resource).ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource); + + result = await mi.AcquireTokenForManagedIdentity(Resource) + .ExecuteAsync().ConfigureAwait(false); + + Assert.IsNotNull(result); + Assert.IsNotNull(result.AccessToken); + Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource); + } + } + } +} diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs index 7d117f7de7..653fc1626d 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ImdsTests.cs @@ -36,6 +36,8 @@ public async Task ImdsErrorHandlingTestAsync(HttpStatusCode statusCode, string e var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(ManagedIdentitySource.Imds, httpManager); + // Adding multiple mock handlers to simulate retries for GatewayTimeout for (int i = 0; i < expectedAttempts; i++) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 9ae8c9ce50..4803e6b88c 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -91,7 +91,9 @@ public async Task ManagedIdentityHappyPathAsync( miBuilder.Config.AccessorOptions = null; var mi = miBuilder.Build(); - + + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -143,6 +145,8 @@ public async Task ManagedIdentityUserAssignedHappyPathAsync( IManagedIdentityApplication mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -192,6 +196,8 @@ public async Task ManagedIdentityDifferentScopesTestAsync( var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -252,6 +258,8 @@ public async Task ManagedIdentityForceRefreshTestAsync( var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -314,6 +322,8 @@ public async Task ManagedIdentityWithClaimsAndCapabilitiesTestAsync( var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -375,6 +385,8 @@ public async Task ManagedIdentityWithClaimsTestAsync( miBuilder.Config.AccessorOptions = null; var mi = miBuilder.Build(); + + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); httpManager.AddManagedIdentityMockHandler( endpoint, @@ -554,6 +566,8 @@ public async Task ManagedIdentityErrorResponseNoPayloadTestAsync(ManagedIdentity var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", managedIdentitySource, statusCode: HttpStatusCode.InternalServerError); httpManager.AddManagedIdentityMockHandler(endpoint, "scope", "", @@ -596,6 +610,8 @@ public async Task ManagedIdentityNullResponseAsync(ManagedIdentitySource managed var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -636,6 +652,8 @@ public async Task ManagedIdentityUnreachableNetworkAsync(ManagedIdentitySource m var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddFailingRequest(new HttpRequestException("A socket operation was attempted to an unreachable network.", new SocketException(10051))); @@ -676,6 +694,8 @@ public async Task ManagedIdentityTestRetryAsync(ManagedIdentitySource managedIde var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, Resource, @@ -1114,6 +1134,8 @@ public async Task InvalidJsonResponseHandling(ManagedIdentitySource managedIdent var mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(managedIdentitySource, httpManager); + httpManager.AddManagedIdentityMockHandler( endpoint, "scope", @@ -1157,6 +1179,8 @@ public async Task ManagedIdentityRequestTokensForDifferentScopesTestAsync( IManagedIdentityApplication mi = miBuilder.Build(); + MockHelpers.AddCredentialEndpointNotFoundHandlers(source, httpManager); + // Mock handler for the initial resource request httpManager.AddManagedIdentityMockHandler(endpoint, initialResource, MockHelpers.GetMsiSuccessfulResponse(), source); @@ -1301,5 +1325,36 @@ public async Task MixedUserAndSystemAssignedManagedIdentityTestAsync() Assert.AreEqual(SystemAssignedClientId, systemAssignedTokens[0].ClientId, "System-assigned ClientId mismatch in cache."); } } + + [TestMethod] + public async Task ManagedIdentitySourceCachingWorksAsExpectedAsync() + { + // Create ApplicationConfiguration and ServiceBundle + var config = new ApplicationConfiguration(MsalClientType.ManagedIdentityClient); + var serviceBundle = new ServiceBundle(config); + + using (new EnvVariableContext()) + { + // Set environment variables to resolve as ServiceFabric + Environment.SetEnvironmentVariable("IDENTITY_ENDPOINT", "http://localhost:40342/metadata/identity/oauth2/token"); + Environment.SetEnvironmentVariable("IDENTITY_HEADER", "dummy-header"); + Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", "dummy-thumbprint"); + + // First call to populate the cache + ManagedIdentitySource firstSource = await ManagedIdentityClient.GetManagedIdentitySourceAsync(serviceBundle).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ServiceFabric, firstSource, "Initial resolution failed to detect ServiceFabric."); + + // Change environment variables to mimic a different identity source (shouldn't affect cached value) + Environment.SetEnvironmentVariable("IDENTITY_ENDPOINT", "http://localhost/other/metadata/identity/oauth2/token"); + Environment.SetEnvironmentVariable("IDENTITY_SERVER_THUMBPRINT", null); + + // Second call should return the cached value + ManagedIdentitySource cachedSource = await ManagedIdentityClient.GetManagedIdentitySourceAsync(serviceBundle).ConfigureAwait(false); + Assert.AreEqual(ManagedIdentitySource.ServiceFabric, cachedSource, "Cached value should remain as ServiceFabric."); + + // Ensure the cache was not recomputed + Assert.AreEqual(firstSource, cachedSource, "The cache was unexpectedly recomputed."); + } + } } } diff --git a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs index 427b7ca149..cdcd872a54 100644 --- a/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs +++ b/tests/devapps/Managed Identity apps/ManagedIdentityAppVM/Program.cs @@ -7,9 +7,15 @@ IIdentityLogger identityLogger = new IdentityLogger(); -IManagedIdentityApplication mi = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) - .WithLogging(identityLogger, true) - .Build(); +var managedIdentitySource = await ManagedIdentityApplication. + GetManagedIdentitySourceAsync().ConfigureAwait(false); + +Console.WriteLine($"Managed identity source detected: {managedIdentitySource}"); + +IManagedIdentityApplication mi = ManagedIdentityApplicationBuilder + .Create(ManagedIdentityId.SystemAssigned) + .WithLogging(identityLogger, true) + .Build(); string? scope = "https://management.azure.com";