Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Credential Probe logic #5116

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ private async Task<AuthenticationResult> SendTokenRequestForManagedIdentityAsync

await ResolveAuthorityAsync().ConfigureAwait(false);

ManagedIdentityClient managedIdentityClient =
new ManagedIdentityClient(AuthenticationRequestParameters.RequestContext);
ManagedIdentityClient managedIdentityClient =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gladjohn - we discussed that you'd start a feature branch? The base branch for this PR is main.

await ManagedIdentityClient.CreateAsync(AuthenticationRequestParameters.RequestContext, cancellationToken)
.ConfigureAwait(false);

ManagedIdentityResponse managedIdentityResponse =
await managedIdentityClient
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Net.Http;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class CredentialManagedIdentitySource : AbstractManagedIdentity
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we just call it SLC or MSIv2 ? Credential is such an overloaded term.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, should it be ImdsV2Source? I would imagine that Arc, SF etc. will be sligthly different.

{
/// <summary>
/// Factory method to create an instance of `CredentialManagedIdentitySource`.
/// </summary>
public static AbstractManagedIdentity Create(RequestContext requestContext)
{
requestContext.Logger.Info(() => "[Managed Identity] Using credential based managed identity.");

return new CredentialManagedIdentitySource(requestContext);
}

private CredentialManagedIdentitySource(RequestContext requestContext) :
base(requestContext, ManagedIdentitySource.Credential)
{
}

/// <summary>
/// Even though the Credential flow does not use this request, we need to satisfy the abstract contract.
/// Return a minimal, valid ManagedIdentityRequest using the fixed credential endpoint.
/// </summary>
/// <param name="resource">The resource identifier (ignored in this flow).</param>
/// <returns>A ManagedIdentityRequest instance using the credential endpoint.</returns>
protected override ManagedIdentityRequest CreateRequest(string resource)
{
// Return a minimal request with the fixed credential endpoint.
return new ManagedIdentityRequest(
HttpMethod.Post,
new Uri("http://169.254.169.254/metadata/identity/credential?cred-api-version=1.0"));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Http;

namespace Microsoft.Identity.Client.ManagedIdentity
{
internal class ImdsCredentialProbeManager
{
private const string CredentialEndpoint = "http://169.254.169.254/metadata/identity/credential";
private const string ProbeBody = ".";
private const string ImdsHeader = "IMDS/";
private readonly IHttpManager _httpManager;
private readonly ILoggerAdapter _logger;

public ImdsCredentialProbeManager(IHttpManager httpManager, ILoggerAdapter logger)
{
_httpManager = httpManager ?? throw new ArgumentNullException(nameof(httpManager));
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
}

public async Task<bool> ExecuteAsync(CancellationToken cancellationToken = default)
{
_logger.Info("[Credential Probe] Initiating probe to IMDS credential endpoint.");

var request = new ManagedIdentityRequest(HttpMethod.Post, new Uri($"{CredentialEndpoint}?cred-api-version=1.0"))
{
Content = ProbeBody
};

HttpContent httpContent = request.CreateHttpContent();

_logger.Info($"[Credential Probe] Sending request to {CredentialEndpoint}");
_logger.Verbose(() => $"[Credential Probe] Request Headers: {string.Join(", ", request.Headers)}");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little dangerous to log headers and body. Maybe don't do it unless we really need it.

_logger.Verbose(() => $"[Credential Probe] Request Body: {ProbeBody}");

try
{
HttpResponse response = await _httpManager.SendRequestAsync(
request.ComputeUri(),
request.Headers,
httpContent,
request.Method,
_logger,
doNotThrow: true,
mtlsCertificate: null,
customHttpClient: null,
cancellationToken).ConfigureAwait(false);

LogResponseDetails(response);

return EvaluateProbeResponse(response);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth recording in telemetry about this? I guess we will know MSAL version and MSI source already - will this be enough?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your spec states that a retry mechanism needs to be used. This doesn't seem implemented.

image

}
catch (Exception ex)
{
_logger.Error($"[Credential Probe] Exception during probe: {ex.Message}");
Copy link
Member

@bgavrilMS bgavrilMS Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really an error is it? It is part of the normal flow. It will confuse ppl. We've did this before in regional and ppl complained that the logs are too noisy. Pls use info level.

I also recommend you do not print the full exception here, just the message should be suficient.

_logger.Error($"[Credential Probe] Stack Trace: {ex.StackTrace}");
return false;
}
}

private void LogResponseDetails(HttpResponse response)
{
if (response == null)
{
_logger.Error("[Credential Probe] No response received from the server.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not an error ?

return;
}

_logger.Info($"[Credential Probe] Response Status Code: {response.StatusCode}");
_logger.Verbose(() => $"[Credential Probe] Response Headers: {string.Join(", ", response.HeadersAsDictionary)}");

if (response.Body != null)
{
_logger.Verbose(() => $"[Credential Probe] Response Body: {response.Body}");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid printing entire reponse body. it could have PII or worse, some secrets.

}
}

private bool EvaluateProbeResponse(HttpResponse response)
{
if (response == null)
{
_logger.Error("[Credential Probe] No response received from the server.");
return false;
}

_logger.Info($"[Credential Probe] Evaluating response from credential endpoint. Status Code: {response.StatusCode}");

if (response.HeadersAsDictionary.TryGetValue("Server", out string serverHeader) &&
serverHeader.StartsWith(ImdsHeader, StringComparison.OrdinalIgnoreCase))
{
_logger.Info($"[Credential Probe] Credential endpoint supported. Server Header: {serverHeader}");
return true;
}

_logger.Warning($"[Credential Probe] Credential endpoint not supported. Status Code: {response.StatusCode}");
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Identity.Client.PlatformsCommon.Shared;
using System.IO;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Http;

namespace Microsoft.Identity.Client.ManagedIdentity
{
Expand All @@ -19,38 +20,74 @@ internal class ManagedIdentityClient
{
private const string WindowsHimdsFilePath = "%Programfiles%\\AzureConnectedMachineAgent\\himds.exe";
private const string LinuxHimdsFilePath = "/opt/azcmagent/bin/himds";

// Cache for the managed identity source
private static ManagedIdentitySource? s_cachedManagedIdentitySource;
private static readonly SemaphoreSlim s_credentialSemaphore = new(1, 1);
private readonly AbstractManagedIdentity _identitySource;

public ManagedIdentityClient(RequestContext requestContext)
internal static async Task<ManagedIdentityClient> CreateAsync(RequestContext requestContext, CancellationToken cancellationToken = default)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the requestContext has the cancellation token, but it's ok to move it as last param. But I recommend you do not use a default.

{
using (requestContext.Logger.LogMethodDuration())
if (requestContext == null)
{
_identitySource = SelectManagedIdentitySource(requestContext);
throw new ArgumentNullException(nameof(requestContext), "RequestContext cannot be null.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't be null. Don't throw.

}

requestContext.Logger?.Info("[ManagedIdentityClient] Creating ManagedIdentityClient.");

AbstractManagedIdentity identitySource = await SelectManagedIdentitySourceAsync(requestContext, cancellationToken).ConfigureAwait(false);

requestContext.Logger?.Info($"[ManagedIdentityClient] Managed identity source selected: {identitySource.GetType().Name}.");

return new ManagedIdentityClient(identitySource);
}

private ManagedIdentityClient(AbstractManagedIdentity identitySource)
{
_identitySource = identitySource ?? throw new ArgumentNullException(nameof(identitySource), "Identity source cannot be null.");
}

/// <summary>
/// Resets the cached managed identity source. Used only for testing purposes.
/// </summary>
internal static void ResetManagedIdentitySourceCache()
{
s_cachedManagedIdentitySource = null;
}

internal Task<ManagedIdentityResponse> SendTokenRequestForManagedIdentityAsync(AcquireTokenForManagedIdentityParameters parameters, CancellationToken cancellationToken)
{
return _identitySource.AuthenticateAsync(parameters, cancellationToken);
}

// This method tries to create managed identity source for different sources, if none is created then defaults to IMDS.
private static AbstractManagedIdentity SelectManagedIdentitySource(RequestContext requestContext)
/// <summary>
/// This method tries to create managed identity source for different sources.
/// If none is created then defaults to IMDS.
/// </summary>
/// <param name="requestContext"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private static async Task<AbstractManagedIdentity> SelectManagedIdentitySourceAsync(RequestContext requestContext, CancellationToken cancellationToken = default)
{
return GetManagedIdentitySource(requestContext.Logger) switch
ManagedIdentitySource source = await GetManagedIdentitySourceAsync(requestContext.ServiceBundle, cancellationToken).ConfigureAwait(false);

return source switch
{
ManagedIdentitySource.ServiceFabric => ServiceFabricManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AppService => AppServiceManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.MachineLearning => MachineLearningManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.CloudShell => CloudShellManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.AzureArc => AzureArcManagedIdentitySource.Create(requestContext),
ManagedIdentitySource.Credential => CredentialManagedIdentitySource.Create(requestContext),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming: avoid using "credential". Maybe "ImdsV2" or smth

_ => new ImdsManagedIdentitySource(requestContext)
};
}

// Detect managed identity source based on the availability of environment variables.
// The result of this method is not cached because reading environment variables is cheap.
// This method is perf sensitive any changes should be benchmarked.
/// <summary>
/// Compute the managed identity source based on the environment variables.
/// </summary>
/// <param name="logger"></param>
/// <returns></returns>
internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter logger = null)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
Expand Down Expand Up @@ -97,6 +134,128 @@ internal static ManagedIdentitySource GetManagedIdentitySource(ILoggerAdapter lo
}
}

/// <summary>
/// Compute the managed identity source based on the environment variables and the probe.
/// </summary>
/// <param name="serviceBundle"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
public static async Task<ManagedIdentitySource> GetManagedIdentitySourceAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic is too complex. Why do we have 2 public methods - one with the probe and one without the probe?

IServiceBundle serviceBundle,
CancellationToken cancellationToken = default)
{
if (serviceBundle == null)
{
throw new ArgumentNullException(nameof(serviceBundle), "ServiceBundle is required to initialize the probe manager.");
}

ILoggerAdapter logger = serviceBundle.ApplicationLogger;

logger.Verbose(() => s_cachedManagedIdentitySource.HasValue
? "[Managed Identity] Using cached managed identity source."
: "[Managed Identity] Computing managed identity source asynchronously.");

if (s_cachedManagedIdentitySource.HasValue)
{
return s_cachedManagedIdentitySource.Value;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't Azure SDK need a way to reset this? For their tests?

Maybe it's time we expose a "ResetCachesForTest()" method in the extensibility namespace.

}

// Use SemaphoreSlim to prevent multiple threads from computing at the same time
logger.Verbose(() => "[Managed Identity] Entering managed identity source semaphore.");
await s_credentialSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
logger.Verbose(() => "[Managed Identity] Entered managed identity source semaphore.");

try
{
// Ensure another thread didn't set this while waiting on semaphore
if (s_cachedManagedIdentitySource.HasValue)
{
return s_cachedManagedIdentitySource.Value;
}

// Initialize probe manager
var probeManager = new ImdsCredentialProbeManager(
serviceBundle.HttpManager,
serviceBundle.ApplicationLogger);

// Compute the managed identity source
s_cachedManagedIdentitySource = await ComputeManagedIdentitySourceAsync(
probeManager,
serviceBundle.ApplicationLogger,
cancellationToken).ConfigureAwait(false);

logger.Info($"[Managed Identity] Managed identity source determined: {s_cachedManagedIdentitySource.Value}.");

return s_cachedManagedIdentitySource.Value;
}
finally
{
s_credentialSemaphore.Release();
logger.Verbose(() => "[Managed Identity] Released managed identity source semaphore.");
}
}

/// <summary>
/// Compute the managed identity source based on the environment variables and the probe.
/// </summary>
/// <param name="imdsCredentialProbeManager"></param>
/// <param name="logger"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
private static async Task<ManagedIdentitySource> ComputeManagedIdentitySourceAsync(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please have a look at all these method names and consolidate. They are so many "select source", "compute source", etc.

ImdsCredentialProbeManager imdsCredentialProbeManager,
ILoggerAdapter logger,
CancellationToken cancellationToken)
{
string identityEndpoint = EnvironmentVariables.IdentityEndpoint;
string identityHeader = EnvironmentVariables.IdentityHeader;
string identityServerThumbprint = EnvironmentVariables.IdentityServerThumbprint;
string msiEndpoint = EnvironmentVariables.MsiEndpoint;
string imdsEndpoint = EnvironmentVariables.ImdsEndpoint;
string msiSecretMachineLearning = EnvironmentVariables.MsiSecret;

if (!string.IsNullOrEmpty(identityEndpoint) && !string.IsNullOrEmpty(identityHeader))
{
if (!string.IsNullOrEmpty(identityServerThumbprint))
{
return ManagedIdentitySource.ServiceFabric;
}
else
{
return ManagedIdentitySource.AppService;
}
}
else if (!string.IsNullOrEmpty(msiSecretMachineLearning) && !string.IsNullOrEmpty(msiEndpoint))
{
return ManagedIdentitySource.MachineLearning;
}
else if (!string.IsNullOrEmpty(msiEndpoint))
{
return ManagedIdentitySource.CloudShell;
}
else if (ValidateAzureArcEnvironment(identityEndpoint, imdsEndpoint, logger))
{
return ManagedIdentitySource.AzureArc;
}
else
{
logger?.Info("[Managed Identity] Probing for credential endpoint.");
bool isSuccess = await imdsCredentialProbeManager.ExecuteAsync(cancellationToken).ConfigureAwait(false);

if (isSuccess)
{
logger?.Info("[Managed Identity] Credential endpoint detected.");
return ManagedIdentitySource.Credential;
}
else
{
logger?.Verbose(() => "[Managed Identity] Defaulting to IMDS as credential endpoint not detected.");
return ManagedIdentitySource.DefaultToImds;
}
}
}

// Method to return true if a file exists and is not empty to validate the Azure arc environment.
private static bool ValidateAzureArcEnvironment(string identityEndpoint, string imdsEndpoint, ILoggerAdapter logger)
{
Expand All @@ -118,8 +277,8 @@ private static bool ValidateAzureArcEnvironment(string identityEndpoint, string
{
logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is available through file detection.");
return true;
}
}

logger?.Verbose(() => "[Managed Identity] Azure Arc managed identity is not available.");
return false;
}
Expand Down
Loading
Loading