Skip to content

Commit

Permalink
Add access_token_sha256_to_refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Feb 12, 2025
1 parent 0340f5e commit da01336
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
# so this test_application file contains only unit tests without dependency.
import hashlib
import json
import logging
import sys
Expand Down Expand Up @@ -56,6 +57,35 @@ def test_bytes_to_bytes(self):
self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))


def fake_token_getter(
*,
access_token: str = "an access token",
status_code: int = 200,
expires_in: int = 3600,
token_type: str = "Bearer",
payload: dict = None,
headers: dict = None,
):
"""A helper to create a fake token getter,
which will be consumed by ClientApplication's acquire methods' post parameter.
Generic mock.patch() is inconvenient because:
1. If you patch it at or above oauth2.py _obtain_token(), token cache is not populated.
2. If you patch it at request.post(), your test cases become fragile because
more http round-trips may be added for future flows,
then your existing test case would break until you mock new round-trips.
"""
return lambda url, *args, **kwargs: MinimalResponse(
status_code=status_code,
text=json.dumps(payload or {
"access_token": access_token,
"expires_in": expires_in,
"token_type": token_type,
}),
headers=headers,
)


class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -856,3 +886,30 @@ def test_app_did_not_register_redirect_uri_should_error_out(self):
)
self.assertEqual(result.get("error"), "broker_error")


@patch("msal.authority.tenant_discovery", new=Mock(return_value={
"authorization_endpoint": "https://contoso.com/placeholder",
"token_endpoint": "https://contoso.com/placeholder",
}))
class AccessTokenToRefreshTestCase(unittest.TestCase):
def test_mismatching_hash_should_not_trigger_refresh(self):
scopes = ["scope"]
old_token = "old AT"
new_token = "new AT"
app = msal.ConfidentialClientApplication("foo", client_credential="bar")
app.acquire_token_for_client(scopes, post=fake_token_getter(access_token=old_token))
self.assertNotEqual(app.token_cache._cache, {}, "Cache should have been populated")

result = app.acquire_token_for_client(
scopes,
access_token_sha256_to_refresh="mismatching hash",
post=fake_token_getter(access_token=new_token))
self.assertEqual(result.get("access_token"), old_token, "Should hit old token")
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_CACHE)

result = app.acquire_token_for_client(
scopes,
access_token_sha256_to_refresh=hashlib.sha256(old_token.encode()).hexdigest(),
post=fake_token_getter(access_token=new_token))
self.assertEqual(result.get("access_token"), new_token, "Should obtain new token")
self.assertEqual(result.get("token_source"), app._TOKEN_SOURCE_IDP)

0 comments on commit da01336

Please sign in to comment.