Skip to content

Commit

Permalink
Merge pull request #85 from AzureAD/release-0.6.1
Browse files Browse the repository at this point in the history
Release 0.6.1
  • Loading branch information
rayluo authored Aug 13, 2019
2 parents bb80636 + 0e19bc6 commit 4b34fd6
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
2 changes: 1 addition & 1 deletion msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "0.6.0"
__version__ = "0.6.1"

logger = logging.getLogger(__name__)

Expand Down
25 changes: 21 additions & 4 deletions msal/oauth2cli/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,29 @@

from . import oauth2

def decode_part(raw, encoding="utf-8"):
"""Decode a part of the JWT.
def base64decode(raw):
"""A helper can handle a padding-less raw input"""
JWT is encoded by padding-less base64url,
based on `JWS specs <https://tools.ietf.org/html/rfc7515#appendix-C>`_.
:param encoding:
If you are going to decode the first 2 parts of a JWT, i.e. the header
or the payload, the default value "utf-8" would work fine.
If you are going to decode the last part i.e. the signature part,
it is a binary string so you should use `None` as encoding here.
"""
raw += '=' * (-len(raw) % 4) # https://stackoverflow.com/a/32517907/728675
return base64.b64decode(raw).decode("utf-8")
raw = str(
# On Python 2.7, argument of urlsafe_b64decode must be str, not unicode.
# This is not required on Python 3.
raw)
output = base64.urlsafe_b64decode(raw)
if encoding:
output = output.decode(encoding)
return output

base64decode = decode_part # Obsolete. For backward compatibility only.

def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None):
"""Decodes and validates an id_token and returns its claims as a dictionary.
Expand All @@ -19,7 +36,7 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None)
and it may contain other optional content such as "preferred_username",
`maybe more <https://openid.net/specs/openid-connect-core-1_0.html#Claims>`_
"""
decoded = json.loads(base64decode(id_token.split('.')[1]))
decoded = json.loads(decode_part(id_token.split('.')[1]))
err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
if issuer and issuer != decoded["iss"]:
# https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
Expand Down
4 changes: 2 additions & 2 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from .authority import canonicalize
from .oauth2cli.oidc import base64decode, decode_id_token
from .oauth2cli.oidc import decode_part, decode_id_token


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,7 +124,7 @@ def add(self, event, now=None):
client_info = {}
home_account_id = None # It would remain None in client_credentials flow
if "client_info" in response: # We asked for it, and AAD will provide it
client_info = json.loads(base64decode(response["client_info"]))
client_info = json.loads(decode_part(response["client_info"]))
home_account_id = "{uid}.{utid}".format(**client_info)
elif id_token_claims: # This would be an end user on ADFS-direct scenario
client_info["uid"] = id_token_claims.get("sub")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_assertion.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import json

from msal.oauth2cli import JwtSigner
from msal.oauth2cli.oidc import base64decode
from msal.oauth2cli import JwtAssertionCreator
from msal.oauth2cli.oidc import decode_part

from tests import unittest


class AssertionTestCase(unittest.TestCase):
def test_extra_claims(self):
assertion = JwtSigner(key=None, algorithm="none").sign_assertion(
assertion = JwtAssertionCreator(key=None, algorithm="none").sign_assertion(
"audience", "issuer", additional_claims={"client_ip": "1.2.3.4"})
payload = json.loads(base64decode(assertion.split(b'.')[1].decode('utf-8')))
payload = json.loads(decode_part(assertion.split(b'.')[1].decode('utf-8')))
self.assertEqual("1.2.3.4", payload.get("client_ip"))

0 comments on commit 4b34fd6

Please sign in to comment.