-
Notifications
You must be signed in to change notification settings - Fork 353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Credential Probe logic #5116
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
using System; | ||
using System.Net.Http; | ||
using Microsoft.Identity.Client.Core; | ||
using Microsoft.Identity.Client.Internal; | ||
|
||
namespace Microsoft.Identity.Client.ManagedIdentity | ||
{ | ||
internal class CredentialManagedIdentitySource : AbstractManagedIdentity | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: maybe we just call it SLC or MSIv2 ? Credential is such an overloaded term. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, should it be |
||
{ | ||
/// <summary> | ||
/// Factory method to create an instance of `CredentialManagedIdentitySource`. | ||
/// </summary> | ||
public static AbstractManagedIdentity Create(RequestContext requestContext) | ||
{ | ||
requestContext.Logger.Info(() => "[Managed Identity] Using credential based managed identity."); | ||
|
||
return new CredentialManagedIdentitySource(requestContext); | ||
} | ||
|
||
private CredentialManagedIdentitySource(RequestContext requestContext) : | ||
base(requestContext, ManagedIdentitySource.Credential) | ||
{ | ||
} | ||
|
||
/// <summary> | ||
/// Even though the Credential flow does not use this request, we need to satisfy the abstract contract. | ||
/// Return a minimal, valid ManagedIdentityRequest using the fixed credential endpoint. | ||
/// </summary> | ||
/// <param name="resource">The resource identifier (ignored in this flow).</param> | ||
/// <returns>A ManagedIdentityRequest instance using the credential endpoint.</returns> | ||
protected override ManagedIdentityRequest CreateRequest(string resource) | ||
{ | ||
// Return a minimal request with the fixed credential endpoint. | ||
return new ManagedIdentityRequest( | ||
HttpMethod.Post, | ||
new Uri("http://169.254.169.254/metadata/identity/credential?cred-api-version=1.0")); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
using System; | ||
using System.Net; | ||
using System.Net.Http; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Microsoft.Identity.Client.Core; | ||
using Microsoft.Identity.Client.Http; | ||
|
||
namespace Microsoft.Identity.Client.ManagedIdentity | ||
{ | ||
internal class ImdsCredentialProbeManager | ||
{ | ||
private const string CredentialEndpoint = "http://169.254.169.254/metadata/identity/credential"; | ||
private const string ProbeBody = "."; | ||
private const string ImdsHeader = "IMDS/"; | ||
private readonly IHttpManager _httpManager; | ||
private readonly ILoggerAdapter _logger; | ||
|
||
public ImdsCredentialProbeManager(IHttpManager httpManager, ILoggerAdapter logger) | ||
{ | ||
_httpManager = httpManager ?? throw new ArgumentNullException(nameof(httpManager)); | ||
_logger = logger ?? throw new ArgumentNullException(nameof(logger)); | ||
} | ||
|
||
public async Task<bool> ExecuteAsync(CancellationToken cancellationToken = default) | ||
{ | ||
_logger.Info("[Credential Probe] Initiating probe to IMDS credential endpoint."); | ||
|
||
var request = new ManagedIdentityRequest(HttpMethod.Post, new Uri($"{CredentialEndpoint}?cred-api-version=1.0")) | ||
{ | ||
Content = ProbeBody | ||
}; | ||
|
||
HttpContent httpContent = request.CreateHttpContent(); | ||
|
||
_logger.Info($"[Credential Probe] Sending request to {CredentialEndpoint}"); | ||
_logger.Verbose(() => $"[Credential Probe] Request Headers: {string.Join(", ", request.Headers)}"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a little dangerous to log headers and body. Maybe don't do it unless we really need it. |
||
_logger.Verbose(() => $"[Credential Probe] Request Body: {ProbeBody}"); | ||
|
||
try | ||
{ | ||
HttpResponse response = await _httpManager.SendRequestAsync( | ||
request.ComputeUri(), | ||
request.Headers, | ||
httpContent, | ||
request.Method, | ||
_logger, | ||
doNotThrow: true, | ||
mtlsCertificate: null, | ||
customHttpClient: null, | ||
cancellationToken).ConfigureAwait(false); | ||
|
||
LogResponseDetails(response); | ||
|
||
return EvaluateProbeResponse(response); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it worth recording in telemetry about this? I guess we will know MSAL version and MSI source already - will this be enough? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
catch (Exception ex) | ||
{ | ||
_logger.Error($"[Credential Probe] Exception during probe: {ex.Message}"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not really an error is it? It is part of the normal flow. It will confuse ppl. We've did this before in regional and ppl complained that the logs are too noisy. Pls use info level. I also recommend you do not print the full exception here, just the message should be suficient. |
||
_logger.Error($"[Credential Probe] Stack Trace: {ex.StackTrace}"); | ||
return false; | ||
} | ||
} | ||
|
||
private void LogResponseDetails(HttpResponse response) | ||
{ | ||
if (response == null) | ||
{ | ||
_logger.Error("[Credential Probe] No response received from the server."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not an error ? |
||
return; | ||
} | ||
|
||
_logger.Info($"[Credential Probe] Response Status Code: {response.StatusCode}"); | ||
_logger.Verbose(() => $"[Credential Probe] Response Headers: {string.Join(", ", response.HeadersAsDictionary)}"); | ||
|
||
if (response.Body != null) | ||
{ | ||
_logger.Verbose(() => $"[Credential Probe] Response Body: {response.Body}"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avoid printing entire reponse body. it could have PII or worse, some secrets. |
||
} | ||
} | ||
|
||
private bool EvaluateProbeResponse(HttpResponse response) | ||
{ | ||
if (response == null) | ||
{ | ||
_logger.Error("[Credential Probe] No response received from the server."); | ||
return false; | ||
} | ||
|
||
_logger.Info($"[Credential Probe] Evaluating response from credential endpoint. Status Code: {response.StatusCode}"); | ||
|
||
if (response.HeadersAsDictionary.TryGetValue("Server", out string serverHeader) && | ||
serverHeader.StartsWith(ImdsHeader, StringComparison.OrdinalIgnoreCase)) | ||
{ | ||
_logger.Info($"[Credential Probe] Credential endpoint supported. Server Header: {serverHeader}"); | ||
return true; | ||
} | ||
|
||
_logger.Warning($"[Credential Probe] Credential endpoint not supported. Status Code: {response.StatusCode}"); | ||
return false; | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
using Microsoft.Identity.Client.PlatformsCommon.Shared; | ||
using System.IO; | ||
using Microsoft.Identity.Client.Core; | ||
using Microsoft.Identity.Client.Http; | ||
|
||
namespace Microsoft.Identity.Client.ManagedIdentity | ||
{ | ||
|
@@ -19,38 +20,74 @@ internal class ManagedIdentityClient | |
{ | ||
private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe"; | ||
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds"; | ||
|
||
// Cache for the managed identity source | ||
private static ManagedIdentitySource? s_cachedManagedIdentitySource; | ||
private static readonly SemaphoreSlim s_credentialSemaphore = new(1, 1); | ||
private readonly AbstractManagedIdentity _identitySource; | ||
|
||
public ManagedIdentityClient(RequestContext requestContext) | ||
internal static async Task<ManagedIdentityClient> CreateAsync(RequestContext requestContext, CancellationToken cancellationToken = default) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe the requestContext has the cancellation token, but it's ok to move it as last param. But I recommend you do not use a default. |
||
{ | ||
using (requestContext.Logger.LogMethodDuration()) | ||
if (requestContext == null) | ||
{ | ||
_identitySource = SelectManagedIdentitySource(requestContext); | ||
throw new ArgumentNullException(nameof(requestContext), "RequestContext cannot be null."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't be null. Don't throw. |
||
} | ||
|
||
requestContext.Logger?.Info("[ManagedIdentityClient] Creating ManagedIdentityClient."); | ||
|
||
AbstractManagedIdentity identitySource = await SelectManagedIdentitySourceAsync(requestContext, cancellationToken).ConfigureAwait(false); | ||
|
||
requestContext.Logger?.Info($"[ManagedIdentityClient] Managed identity source selected: {identitySource.GetType().Name}."); | ||
|
||
return new ManagedIdentityClient(identitySource); | ||
} | ||
|
||
private ManagedIdentityClient(AbstractManagedIdentity identitySource) | ||
{ | ||
_identitySource = identitySource ?? throw new ArgumentNullException(nameof(identitySource), "Identity source cannot be null."); | ||
} | ||
|
||
/// <summary> | ||
/// Resets the cached managed identity source. Used only for testing purposes. | ||
/// </summary> | ||
internal static void ResetManagedIdentitySourceCache() | ||
{ | ||
s_cachedManagedIdentitySource = null; | ||
} | ||
|
||
internal Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityAsync(AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken) | ||
{ | ||
return _identitySource.AuthenticateAsync(parameters, cancellationToken); | ||
} | ||
|
||
// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS. | ||
private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContext requestContext) | ||
/// <summary> | ||
/// This method tries to create managed identity source for different sources. | ||
/// If none is created then defaults to IMDS. | ||
/// </summary> | ||
/// <param name="requestContext"></param> | ||
/// <param name="cancellationToken"></param> | ||
/// <returns></returns> | ||
private static async Task<AbstractManagedIdentity> SelectManagedIdentitySourceAsync(RequestContext requestContext, CancellationToken cancellationToken = default) | ||
{ | ||
return GetManagedIdentitySource(requestContext.Logger) switch | ||
ManagedIdentitySource source = await GetManagedIdentitySourceAsync(requestContext.ServiceBundle, cancellationToken).ConfigureAwait(false); | ||
|
||
return source switch | ||
{ | ||
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext), | ||
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext), | ||
ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext), | ||
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext), | ||
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext), | ||
ManagedIdentitySource.Credential => CredentialManagedIdentitySource.Create(requestContext), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. naming: avoid using "credential". Maybe "ImdsV2" or smth |
||
_ => new ImdsManagedIdentitySource(requestContext) | ||
}; | ||
} | ||
|
||
// Detect managed identity source based on the availability of environment variables. | ||
// The result of this method is not cached because reading environment variables is cheap. | ||
// This method is perf sensitive any changes should be benchmarked. | ||
/// <summary> | ||
/// Compute the managed identity source based on the environment variables. | ||
/// </summary> | ||
/// <param name="logger"></param> | ||
/// <returns></returns> | ||
internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter logger = null) | ||
{ | ||
string identityEndpoint = EnvironmentVariables.IdentityEndpoint; | ||
|
@@ -97,6 +134,128 @@ internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter lo | |
} | ||
} | ||
|
||
/// <summary> | ||
/// Compute the managed identity source based on the environment variables and the probe. | ||
/// </summary> | ||
/// <param name="serviceBundle"></param> | ||
/// <param name="cancellationToken"></param> | ||
/// <returns></returns> | ||
/// <exception cref="ArgumentNullException"></exception> | ||
public static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Logic is too complex. Why do we have 2 public methods - one with the probe and one without the probe? |
||
IServiceBundle serviceBundle, | ||
CancellationToken cancellationToken = default) | ||
{ | ||
if (serviceBundle == null) | ||
{ | ||
throw new ArgumentNullException(nameof(serviceBundle), "ServiceBundle is required to initialize the probe manager."); | ||
} | ||
|
||
ILoggerAdapter logger = serviceBundle.ApplicationLogger; | ||
|
||
logger.Verbose(() => s_cachedManagedIdentitySource.HasValue | ||
? "[Managed Identity] Using cached managed identity source." | ||
: "[Managed Identity] Computing managed identity source asynchronously."); | ||
|
||
if (s_cachedManagedIdentitySource.HasValue) | ||
{ | ||
return s_cachedManagedIdentitySource.Value; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't Azure SDK need a way to reset this? For their tests? Maybe it's time we expose a "ResetCachesForTest()" method in the extensibility namespace. |
||
} | ||
|
||
// Use SemaphoreSlim to prevent multiple threads from computing at the same time | ||
logger.Verbose(() => "[Managed Identity] Entering managed identity source semaphore."); | ||
await s_credentialSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); | ||
logger.Verbose(() => "[Managed Identity] Entered managed identity source semaphore."); | ||
|
||
try | ||
{ | ||
// Ensure another thread didn't set this while waiting on semaphore | ||
if (s_cachedManagedIdentitySource.HasValue) | ||
{ | ||
return s_cachedManagedIdentitySource.Value; | ||
} | ||
|
||
// Initialize probe manager | ||
var probeManager = new ImdsCredentialProbeManager( | ||
serviceBundle.HttpManager, | ||
serviceBundle.ApplicationLogger); | ||
|
||
// Compute the managed identity source | ||
s_cachedManagedIdentitySource = await ComputeManagedIdentitySourceAsync( | ||
probeManager, | ||
serviceBundle.ApplicationLogger, | ||
cancellationToken).ConfigureAwait(false); | ||
|
||
logger.Info($"[Managed Identity] Managed identity source determined: {s_cachedManagedIdentitySource.Value}."); | ||
|
||
return s_cachedManagedIdentitySource.Value; | ||
} | ||
finally | ||
{ | ||
s_credentialSemaphore.Release(); | ||
logger.Verbose(() => "[Managed Identity] Released managed identity source semaphore."); | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// Compute the managed identity source based on the environment variables and the probe. | ||
/// </summary> | ||
/// <param name="imdsCredentialProbeManager"></param> | ||
/// <param name="logger"></param> | ||
/// <param name="cancellationToken"></param> | ||
/// <returns></returns> | ||
private static async Task<ManagedIdentitySource> ComputeManagedIdentitySourceAsync( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please have a look at all these method names and consolidate. They are so many "select source", "compute source", etc. |
||
ImdsCredentialProbeManager imdsCredentialProbeManager, | ||
ILoggerAdapter logger, | ||
CancellationToken cancellationToken) | ||
{ | ||
string identityEndpoint = EnvironmentVariables.IdentityEndpoint; | ||
string identityHeader = EnvironmentVariables.IdentityHeader; | ||
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint; | ||
string msiEndpoint = EnvironmentVariables.MsiEndpoint; | ||
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint; | ||
string msiSecretMachineLearning = EnvironmentVariables.MsiSecret; | ||
|
||
if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader)) | ||
{ | ||
if (!string.IsNullOrEmpty(identityServerThumbprint)) | ||
{ | ||
return ManagedIdentitySource.ServiceFabric; | ||
} | ||
else | ||
{ | ||
return ManagedIdentitySource.AppService; | ||
} | ||
} | ||
else if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint)) | ||
{ | ||
return ManagedIdentitySource.MachineLearning; | ||
} | ||
else if (!string.IsNullOrEmpty(msiEndpoint)) | ||
{ | ||
return ManagedIdentitySource.CloudShell; | ||
} | ||
else if (ValidateAzureArcEnvironment(identityEndpoint, imdsEndpoint, logger)) | ||
{ | ||
return ManagedIdentitySource.AzureArc; | ||
} | ||
else | ||
{ | ||
logger?.Info("[Managed Identity] Probing for credential endpoint."); | ||
bool isSuccess = await imdsCredentialProbeManager.ExecuteAsync(cancellationToken).ConfigureAwait(false); | ||
|
||
if (isSuccess) | ||
{ | ||
logger?.Info("[Managed Identity] Credential endpoint detected."); | ||
return ManagedIdentitySource.Credential; | ||
} | ||
else | ||
{ | ||
logger?.Verbose(() => "[Managed Identity] Defaulting to IMDS as credential endpoint not detected."); | ||
return ManagedIdentitySource.DefaultToImds; | ||
} | ||
} | ||
} | ||
|
||
// Method to return true if a file exists and is not empty to validate the Azure arc environment. | ||
private static bool ValidateAzureArcEnvironment(string identityEndpoint, string imdsEndpoint, ILoggerAdapter logger) | ||
{ | ||
|
@@ -118,8 +277,8 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string | |
{ | ||
logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is available through file detection."); | ||
return true; | ||
} | ||
} | ||
|
||
logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available."); | ||
return false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gladjohn - we discussed that you'd start a feature branch? The base branch for this PR is main.