Skip to content

Commit

Permalink
Merge pull request #261 from sirosen/feature/tok-by-scopes
Browse files Browse the repository at this point in the history
First draft of `by_scopes` attr on token response
  • Loading branch information
jaswilli authored Dec 13, 2017
2 parents 803aa67 + 8ae6398 commit 33e5ace
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 4 deletions.
98 changes: 96 additions & 2 deletions globus_sdk/auth/token_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import requests
import time
import six

import jwt

Expand All @@ -24,10 +25,71 @@ def _convert_token_info_dict(source_dict):
'access_token': source_dict['access_token'],
'refresh_token': source_dict.get('refresh_token'),
'token_type': source_dict.get('token_type'),
'expires_at_seconds': int(time.time() + expires_in)
'expires_at_seconds': int(time.time() + expires_in),
'resource_server': source_dict['resource_server']
}


class _ByScopesGetter(object):
"""
A fancy dict-like object for looking up token data by scope name.
Allows usage like
>>> tokens = OAuthTokenResponse(...)
>>> tok = tokens.by_scopes['openid profile']['access_token']
"""
def __init__(self, scope_map):
self.scope_map = scope_map

def __str__(self):
return json.dumps(self.scope_map)

def __iter__(self):
"""iteration gets you every individual scope"""
return iter(self.scope_map.keys())

def __getitem__(self, scopename):
if not isinstance(scopename, six.string_types):
raise KeyError('by_scopes cannot contain non-string value "{}"'
.format(scopename))

# split on spaces
scopes = scopename.split()
# collect every matching token in a set to dedup
# but collect actual results (dicts) in a list
rs_names = set()
toks = []
for scope in scopes:
try:
rs_names.add(self.scope_map[scope]['resource_server'])
toks.append(self.scope_map[scope])
except KeyError:
raise KeyError(('Scope specifier "{}" contains scope "{}" '
"which was not found"
).format(scopename, scope))
# if there isn't exactly 1 token, it's an error
if len(rs_names) != 1:
raise KeyError(
'Scope specifier "{}" did not match exactly one token!'
.format(scopename))
# pop the only element in the set
return toks.pop()

def __contains__(self, item):
"""
contains is driven by checking against getitem
that way, the definitions are always "in sync" if we update them in
the future
"""
try:
self.__getitem__(item)
return True
except KeyError:
pass

return False


class OAuthTokenResponse(GlobusHTTPResponse):
"""
Class for responses from the OAuth2 code for tokens exchange used in
Expand All @@ -36,11 +98,20 @@ class OAuthTokenResponse(GlobusHTTPResponse):
def __init__(self, *args, **kwargs):
GlobusHTTPResponse.__init__(self, *args, **kwargs)
self._init_rs_dict()
self._init_scopes_getter()

def _init_scopes_getter(self):
scope_map = {}
for rs, tok_data in self._by_resource_server.items():
for s in tok_data["scope"].split():
scope_map[s] = tok_data
self._by_scopes = _ByScopesGetter(scope_map)

def _init_rs_dict(self):
# call the helper at the top level
self._by_resource_server = {
self['resource_server']: _convert_token_info_dict(self)}
self['resource_server']: _convert_token_info_dict(self)
}
# call the helper on everything in 'other_tokens'
self._by_resource_server.update(dict(
(unprocessed_item['resource_server'],
Expand All @@ -59,6 +130,29 @@ def by_resource_server(self):
"""
return self._by_resource_server

@property
def by_scopes(self):
"""
Representation of the token response in a dict-like object indexed by
scope name (or even space delimited scope names, so long as they match
the same token).
If you request scopes `scope1 scope2 scope3`, where `scope1` and
`scope2` are for the same service (and therefore map to the same
token), but `scope3` is for a different service, the following forms of
access are valid:
>>> tokens = ...
>>> # single scope
>>> token_data = tokens.by_scopes['scope1']
>>> token_data = tokens.by_scopes['scope2']
>>> token_data = tokens.by_scopes['scope3']
>>> # matching scopes
>>> token_data = tokens.by_scopes['scope1 scope2']
>>> token_data = tokens.by_scopes['scope2 scope1']
"""
return self._by_scopes

def decode_id_token(self, auth_client=None):
"""
A parsed ID Token (OIDC) as a dict.
Expand Down
64 changes: 62 additions & 2 deletions tests/unit/responses/test_token_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def setUp(self):
"id_token": "invalid_id_token",
"access_token": SDKTESTER1A_ID_ACCESS_TOKEN}
self.other_token2 = { # valid id_token with invalid access_token
"resource_server": "server3", "expires_in": 30, "scope": "scope3",
"resource_server": "server3", "expires_in": 30,
"scope": "scope3 scope4",
"refresh_token": "RT3", "other_tokens": [], "token_type": "3",
"id_token": SDKTESTER1A_NATIVE1_ID_TOKEN,
"access_token": "invalid_access_token"}
Expand Down Expand Up @@ -114,6 +115,35 @@ def test_by_resource_server(self):
self.assertIn(server_data["expires_at_seconds"],
(expected - 1, expected, expected + 1))

def test_by_scopes(self):
"""
Gets by_scopes attribute from test response,
Confirms expected values found for top and other tokens
"""
by_scopes = self.response.by_scopes

# confirm data by server matches known token values
for scope, token in [("scope1", self.top_token),
("scope2", self.other_token1),
("scope3", self.other_token2),
("scope4", self.other_token2),
("scope3 scope4", self.other_token2),
("scope4 scope3", self.other_token2)]:
scope_data = by_scopes[scope]
for key in ["scope", "access_token",
"refresh_token", "token_type"]:
self.assertEqual(scope_data[key], token[key])
# assumes test runs within 1 second range
expected = int(time.time()) + token["expires_in"]
self.assertIn(scope_data["expires_at_seconds"],
(expected - 1, expected, expected + 1))

self.assertIn('scope1', by_scopes)
self.assertIn('scope3', by_scopes)
self.assertNotIn('scope1 scope2', by_scopes)
self.assertNotIn('scope1 scope3', by_scopes)
self.assertIn('scope4 scope3', by_scopes)

@retry_errors()
def test_decode_id_token_invalid_id(self):
"""
Expand Down Expand Up @@ -159,7 +189,8 @@ def setUp(self):
"resource_server": "server2", "expires_in": 20, "scope": "scope2",
"access_token": "AT2", "refresh_token": "RT2", "token_type": "2"}
self.token3 = {
"resource_server": "server3", "expires_in": 30, "scope": "scope3",
"resource_server": "server3", "expires_in": 30,
"scope": "scope3 scope4",
"access_token": "AT3", "refresh_token": "RT3", "token_type": "3"}

# create the response
Expand Down Expand Up @@ -188,3 +219,32 @@ def test_by_resource_server(self):
expected = int(time.time()) + token["expires_in"]
self.assertIn(server_data["expires_at_seconds"],
(expected - 1, expected, expected + 1))

def test_by_scopes(self):
"""
Gets by_scopes attribute from test response,
Confirms expected values found for top and other tokens
"""
by_scopes = self.response.by_scopes

# confirm data by server matches known token values
for scope, token in [("scope1", self.token1),
("scope2", self.token2),
("scope3", self.token3),
("scope4", self.token3),
("scope3 scope4", self.token3),
("scope4 scope3", self.token3)]:
scope_data = by_scopes[scope]
for key in ["scope", "access_token",
"refresh_token", "token_type"]:
self.assertEqual(scope_data[key], token[key])
# assumes test runs within 1 second range
expected = int(time.time()) + token["expires_in"]
self.assertIn(scope_data["expires_at_seconds"],
(expected - 1, expected, expected + 1))

self.assertIn('scope1', by_scopes)
self.assertIn('scope3', by_scopes)
self.assertNotIn('scope1 scope2', by_scopes)
self.assertNotIn('scope1 scope3', by_scopes)
self.assertIn('scope4 scope3', by_scopes)

0 comments on commit 33e5ace

Please sign in to comment.