Skip to content

Commit aadacda

Browse files
committed
Add user_loader feature (refs #49)
1 parent 5f47c0c commit aadacda

File tree

9 files changed

+306
-21
lines changed

9 files changed

+306
-21
lines changed

flask_jwt_extended/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .utils import (
77
create_refresh_token, create_access_token, get_jwt_identity,
88
get_jwt_claims, set_access_cookies, set_refresh_cookies,
9-
unset_jwt_cookies, get_raw_jwt
9+
unset_jwt_cookies, get_raw_jwt, get_current_user, current_user
1010
)
1111
from .blacklist import (
1212
revoke_token, unrevoke_token, get_stored_tokens, get_all_stored_tokens,

flask_jwt_extended/default_callbacks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,12 @@ def default_revoked_token_callback():
7474
return a general error message with a 401 status code
7575
"""
7676
return jsonify({'msg': 'Token has been revoked'}), 401
77+
78+
79+
def default_user_loader_error_callback(identity):
80+
"""
81+
By default, if a user_loader callback is defined and the callback
82+
function returns None, we return a general error message with a 401
83+
status code
84+
"""
85+
return jsonify({'msg': "Error loading the user {}".format(identity)}), 401

flask_jwt_extended/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,11 @@ class FreshTokenRequired(JWTExtendedException):
5454
protected by fresh_jwt_required
5555
"""
5656
pass
57+
58+
59+
class UserLoadError(JWTExtendedException):
60+
"""
61+
Error raised when a user_loader callback function returns None, indicating
62+
that it cannot or will not load a user for the given identity.
63+
"""
64+
pass

flask_jwt_extended/jwt_manager.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@
66
from flask_jwt_extended.config import config
77
from flask_jwt_extended.exceptions import (
88
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
9-
RevokedTokenError, FreshTokenRequired, CSRFError
9+
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError
1010
)
1111
from flask_jwt_extended.default_callbacks import (
1212
default_expired_token_callback, default_user_claims_callback,
1313
default_user_identity_callback, default_invalid_token_callback,
14-
default_unauthorized_callback,
15-
default_needs_fresh_token_callback,
16-
default_revoked_token_callback
14+
default_unauthorized_callback, default_needs_fresh_token_callback,
15+
default_revoked_token_callback, default_user_loader_error_callback
1716
)
1817
from flask_jwt_extended.tokens import (
19-
encode_refresh_token, decode_jwt,
20-
encode_access_token
18+
encode_refresh_token, decode_jwt, encode_access_token
2119
)
20+
from flask_jwt_extended.utils import get_jwt_identity
2221

2322

2423
class JWTManager(object):
@@ -39,6 +38,8 @@ def __init__(self, app=None):
3938
self._unauthorized_callback = default_unauthorized_callback
4039
self._needs_fresh_token_callback = default_needs_fresh_token_callback
4140
self._revoked_token_callback = default_revoked_token_callback
41+
self._user_loader_callback = None
42+
self._user_loader_error_callback = default_user_loader_error_callback
4243

4344
# Register this extension with the flask app now (if it is provided)
4445
if app is not None:
@@ -101,6 +102,14 @@ def handle_revoked_token_error(e):
101102
def handle_fresh_token_required(e):
102103
return self._needs_fresh_token_callback()
103104

105+
@app.errorhandler(UserLoadError)
106+
def handler_user_load_error(e):
107+
# The identity is already saved before this exception was raised,
108+
# otherwise a different exception would be raised, which is why we
109+
# can safely call get_jwt_identity() here
110+
identity = get_jwt_identity()
111+
return self._user_loader_error_callback(identity)
112+
104113
@staticmethod
105114
def _set_default_configuration_options(app):
106115
"""
@@ -244,6 +253,50 @@ def revoked_token_loader(self, callback):
244253
self._revoked_token_callback = callback
245254
return callback
246255

256+
def user_loader_callback_loader(self, callback):
257+
"""
258+
Sets the callback method to be called to load a user on a protected
259+
endpoint.
260+
261+
By default this is not is not used.
262+
263+
If a callback method is passed in here, it must take one argument,
264+
which is the identity of the user to load. It must return the user
265+
object, or None in the case of an error (which will cause the TODO
266+
error handler to be hit)
267+
"""
268+
self._user_loader_callback = callback
269+
return callback
270+
271+
def user_loader_error_loader(self, callback):
272+
"""
273+
Sets the callback method to be called if a user fails or is refused
274+
to load when calling the _user_loader_callback function (indicated by
275+
that function returning None)
276+
277+
The default implementation will return json:
278+
'{"msg": "Error loading the user <identity>"}' with a 400 status code.
279+
280+
Callback must be a function that takes one argument, the identity of the
281+
user who failed to load.
282+
"""
283+
self._user_loader_error_callback = callback
284+
return callback
285+
286+
def has_user_loader(self):
287+
"""
288+
Returns True if a user_loader_callback has been defined in this
289+
application, False otherwise
290+
"""
291+
return self._user_loader_callback is not None
292+
293+
def user_loader(self, identity):
294+
"""
295+
Calls the _user_loader_callback function (if it is defined) and returns
296+
the resulting user from this callback.
297+
"""
298+
return self._user_loader_callback(identity)
299+
247300
def create_refresh_token(self, identity):
248301
"""
249302
Creates a new refresh token

flask_jwt_extended/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from flask import current_app
2+
from werkzeug.local import LocalProxy
3+
24
try:
35
from flask import _app_ctx_stack as ctx_stack
46
except ImportError: # pragma: no cover
@@ -8,6 +10,10 @@
810
from flask_jwt_extended.tokens import decode_jwt
911

1012

13+
# Proxy to access the current user
14+
current_user = LocalProxy(lambda: get_current_user())
15+
16+
1117
def get_raw_jwt():
1218
"""
1319
Returns the python dictionary which has all of the data in this JWT. If no
@@ -32,6 +38,15 @@ def get_jwt_claims():
3238
return get_raw_jwt().get('user_claims', {})
3339

3440

41+
def get_current_user():
42+
"""
43+
Returns the loaded user from a user_loader callback in a protected endpoint.
44+
If no user was loaded, or if no user_loader callback was defined, this will
45+
return None
46+
"""
47+
return getattr(ctx_stack.top, 'jwt_user', None)
48+
49+
3550
def get_jti(encoded_token):
3651
"""
3752
Returns the JTI given the JWT encoded token
@@ -60,6 +75,16 @@ def create_refresh_token(*args, **kwargs):
6075
return jwt_manager.create_refresh_token(*args, **kwargs)
6176

6277

78+
def user_loader(*args, **kwargs):
79+
jwt_manager = _get_jwt_manager()
80+
return jwt_manager.user_loader(*args, **kwargs)
81+
82+
83+
def has_user_loader(*args, **kwargs):
84+
jwt_manager = _get_jwt_manager()
85+
return jwt_manager.has_user_loader(*args, **kwargs)
86+
87+
6388
def get_csrf_token(encoded_token):
6489
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
6590
return token['csrf']

flask_jwt_extended/view_decorators.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from flask_jwt_extended.config import config
1212
from flask_jwt_extended.exceptions import (
1313
InvalidHeaderError, NoAuthorizationError, WrongTokenError,
14-
FreshTokenRequired, CSRFError
14+
FreshTokenRequired, CSRFError, UserLoadError
1515
)
1616
from flask_jwt_extended.tokens import decode_jwt
17+
from flask_jwt_extended.utils import has_user_loader, user_loader
1718

1819

1920
def jwt_required(fn):
@@ -28,10 +29,9 @@ def jwt_required(fn):
2829
"""
2930
@wraps(fn)
3031
def wrapper(*args, **kwargs):
31-
# Save the jwt in the context so that it can be accessed later by
32-
# the various endpoints that is using this decorator
3332
jwt_data = _decode_jwt_from_request(request_type='access')
3433
ctx_stack.top.jwt = jwt_data
34+
_load_user(jwt_data['identity'])
3535
return fn(*args, **kwargs)
3636
return wrapper
3737

@@ -49,15 +49,11 @@ def jwt_optional(fn):
4949
@wraps(fn)
5050
def wrapper(*args, **kwargs):
5151
try:
52-
# If an acceptable JWT is found in the request, put it into
53-
# the application context
5452
jwt_data = _decode_jwt_from_request(request_type='access')
5553
ctx_stack.top.jwt = jwt_data
54+
_load_user(jwt_data['identity'])
5655
except NoAuthorizationError:
57-
# Allow request to proceed if no authorization header is present
58-
# in the request, but don't modify application context
5956
pass
60-
# Return the decorated function in either case
6157
return fn(*args, **kwargs)
6258
return wrapper
6359

@@ -78,9 +74,8 @@ def wrapper(*args, **kwargs):
7874
if not jwt_data['fresh']:
7975
raise FreshTokenRequired('Fresh token required')
8076

81-
# Save the jwt in the context so that it can be accessed later by
82-
# the various endpoints that is using this decorator
8377
ctx_stack.top.jwt = jwt_data
78+
_load_user(jwt_data['identity'])
8479
return fn(*args, **kwargs)
8580
return wrapper
8681

@@ -93,14 +88,22 @@ def jwt_refresh_token_required(fn):
9388
"""
9489
@wraps(fn)
9590
def wrapper(*args, **kwargs):
96-
# Save the jwt in the context so that it can be accessed later by
97-
# the various endpoints that is using this decorator
9891
jwt_data = _decode_jwt_from_request(request_type='refresh')
9992
ctx_stack.top.jwt = jwt_data
93+
_load_user(jwt_data['identity'])
10094
return fn(*args, **kwargs)
10195
return wrapper
10296

10397

98+
def _load_user(identity):
99+
if has_user_loader():
100+
user = user_loader(identity)
101+
if user is None:
102+
raise UserLoadError("user_loader returned None for {}".format(identity))
103+
else:
104+
ctx_stack.top.jwt_user = user
105+
106+
104107
def _decode_jwt_from_headers():
105108
header_name = config.header_name
106109
header_type = config.header_type

tests/test_jwt_manager.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ def test_class_init(self):
3232
def test_default_user_claims_callback(self):
3333
identity = 'foobar'
3434
m = JWTManager(self.app)
35-
assert m._user_claims_callback(identity) == {}
35+
self.assertEqual(m._user_claims_callback(identity), {})
36+
37+
def test_default_user_identity_callback(self):
38+
identity = 'foobar'
39+
m = JWTManager(self.app)
40+
self.assertEqual(m._user_identity_callback(identity), identity)
3641

3742
def test_default_expired_token_callback(self):
3843
with self.app.test_request_context():
@@ -80,6 +85,24 @@ def test_default_revoked_token_callback(self):
8085
self.assertEqual(status_code, 401)
8186
self.assertEqual(data, {'msg': 'Token has been revoked'})
8287

88+
def test_default_user_loader_callback(self):
89+
m = JWTManager(self.app)
90+
self.assertEqual(m._user_loader_callback, None)
91+
92+
def test_default_user_loader_error_callback(self):
93+
with self.app.test_request_context():
94+
identity = 'foobar'
95+
m = JWTManager(self.app)
96+
result = m._user_loader_error_callback(identity)
97+
status_code, data = self._parse_callback_result(result)
98+
99+
self.assertEqual(status_code, 401)
100+
self.assertEqual(data, {'msg': 'Error loading the user foobar'})
101+
102+
def test_default_has_user_loader(self):
103+
m = JWTManager(self.app)
104+
self.assertEqual(m.has_user_loader(), False)
105+
83106
def test_custom_user_claims_callback(self):
84107
identity = 'foobar'
85108
m = JWTManager(self.app)
@@ -159,3 +182,33 @@ def custom_revoken_token():
159182

160183
self.assertEqual(status_code, 422)
161184
self.assertEqual(data, {'err': 'Nice knowing you!'})
185+
186+
def test_custom_user_loader(self):
187+
with self.app.test_request_context():
188+
m = JWTManager(self.app)
189+
190+
@m.user_loader_callback_loader
191+
def custom_user_loader(identity):
192+
if identity == 'foo':
193+
return None
194+
return identity
195+
196+
identity = 'foobar'
197+
result = m._user_loader_callback(identity)
198+
self.assertEqual(result, identity)
199+
self.assertEqual(m.has_user_loader(), True)
200+
201+
def test_custom_user_loader_error_callback(self):
202+
with self.app.test_request_context():
203+
m = JWTManager(self.app)
204+
205+
@m.user_loader_error_loader
206+
def custom_user_loader_error(identity):
207+
return jsonify({'msg': 'Not found'}), 404
208+
209+
identity = 'foobar'
210+
result = m._user_loader_error_callback(identity)
211+
status_code, data = self._parse_callback_result(result)
212+
213+
self.assertEqual(status_code, 404)
214+
self.assertEqual(data, {'msg': 'Not found'})

tests/test_protected_endpoints.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def partially_protected():
6262
return jsonify({'msg': "protected hello world"})
6363
return jsonify({'msg': "unprotected hello world"})
6464

65-
6665
def _jwt_post(self, url, jwt):
6766
response = self.client.post(url, content_type='application/json',
6867
headers={'Authorization': 'Bearer {}'.format(jwt)})

0 commit comments

Comments
 (0)