Skip to content

Commit c4ecf4f

Browse files
committed
Helper decorater for handling callback functions on errors
1 parent 570f67d commit c4ecf4f

File tree

2 files changed

+95
-74
lines changed

2 files changed

+95
-74
lines changed

flask_jwt_extended/exceptions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,26 @@ class NoAuthHeaderError(JWTExtendedException):
3131
An error getting header information from a request
3232
"""
3333
pass
34+
35+
36+
class WrongTokenError(JWTExtendedException):
37+
"""
38+
Error raised when attempting to use a refresh token to access an endpoint
39+
or vice versa
40+
"""
41+
pass
42+
43+
44+
class RevokedTokenError(JWTExtendedException):
45+
"""
46+
Error raised when a revoked token attempt to access a protected endpoint
47+
"""
48+
pass
49+
50+
51+
class FreshTokenRequired(JWTExtendedException):
52+
"""
53+
Error raised when a valid, non-fresh JWT attempt to access an endpoint
54+
protected by fresh_jwt_required
55+
"""
56+
pass

flask_jwt_extended/utils.py

Lines changed: 72 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from flask_jwt_extended.config import ALGORITHM, REFRESH_EXPIRES, ACCESS_EXPIRES, \
1616
BLACKLIST_ENABLED, BLACKLIST_STORE, BLACKLIST_TOKEN_CHECKS
1717
from flask_jwt_extended.exceptions import JWTEncodeError, JWTDecodeError, \
18-
InvalidHeaderError, NoAuthHeaderError
19-
18+
InvalidHeaderError, NoAuthHeaderError, WrongTokenError, RevokedTokenError, \
19+
FreshTokenRequired
2020

2121
# Proxy for accessing the identity of the JWT in this context
2222
jwt_identity = LocalProxy(lambda: _get_identity())
@@ -133,12 +133,7 @@ def _decode_jwt(token, secret, algorithm):
133133
return data
134134

135135

136-
def _verify_jwt_from_request(secret):
137-
"""
138-
Returns the encoded JWT string from the request
139-
140-
:return: Encoded jwt string, or None if it does not exist
141-
"""
136+
def _decode_jwt_from_request():
142137
# Verify we have the auth header
143138
auth_header = request.headers.get('Authorization', None)
144139
if not auth_header:
@@ -154,9 +149,56 @@ def _verify_jwt_from_request(secret):
154149
raise InvalidHeaderError(msg)
155150

156151
token = parts[1]
152+
secret = _get_secret_key()
157153
return _decode_jwt(token, secret, 'HS256')
158154

159155

156+
def _handle_callbacks_on_error(fn):
157+
"""
158+
Helper decorator that will catch any exceptions we expect to encounter
159+
when dealing with a JWT, and call the appropriate callback function for
160+
handling that error. Callback functions can be set in using the *_loader
161+
methods in jwt_manager.
162+
"""
163+
@wraps(fn)
164+
def wrapper(*args, **kwargs):
165+
m = current_app.jwt_manager
166+
167+
try:
168+
return fn(*args, **kwargs)
169+
except NoAuthHeaderError:
170+
return m.unauthorized_callback()
171+
except jwt.ExpiredSignatureError:
172+
return m.expired_token_callback()
173+
except (InvalidHeaderError, jwt.InvalidTokenError, JWTDecodeError,
174+
WrongTokenError) as e:
175+
return m.invalid_token_callback(str(e))
176+
except RevokedTokenError:
177+
return m.blacklisted_token_callback()
178+
except FreshTokenRequired:
179+
return m.token_needs_refresh_callback()
180+
return wrapper
181+
182+
183+
def _check_blacklist(jwt_data):
184+
if not _blacklist_enabled():
185+
return
186+
187+
store = _get_blacklist_store()
188+
token_type = jwt_data['type']
189+
jti = jwt_data['jti']
190+
191+
if token_type == 'access' and _blacklist_checks() == 'all':
192+
token_status = store[jti]
193+
if token_status != 'active':
194+
raise RevokedTokenError('{} has been revoked'.format)
195+
196+
if token_type == 'refresh' and _blacklist_checks() in ('all', 'refresh'):
197+
token_status = store[jti]
198+
if token_status != 'active':
199+
raise RevokedTokenError('{} has been revoked'.format)
200+
201+
160202
def jwt_required(fn):
161203
"""
162204
If you decorate a vew with this, it will ensure that the requester has a valid
@@ -167,35 +209,18 @@ def jwt_required(fn):
167209
168210
:param fn: The view function to decorate
169211
"""
212+
@_handle_callbacks_on_error
170213
@wraps(fn)
171214
def wrapper(*args, **kwargs):
172215
# Attempt to decode the token
173-
try:
174-
secret = _get_secret_key()
175-
jwt_data = _verify_jwt_from_request(secret)
176-
except NoAuthHeaderError:
177-
return current_app.jwt_manager.unauthorized_callback()
178-
except jwt.ExpiredSignatureError as e:
179-
return current_app.jwt_manager.expired_token_callback()
180-
except (InvalidHeaderError, jwt.InvalidTokenError, JWTDecodeError) as e:
181-
return current_app.jwt_manager.invalid_token_callback(str(e))
216+
jwt_data = _decode_jwt_from_request()
182217

183218
# Verify this is an access token
184219
if jwt_data['type'] != 'access':
185-
err_msg = 'Only access tokens can access this endpoint'
186-
return current_app.jwt_manager.invalid_token_callback(err_msg)
187-
188-
# TODO move this into common helper function that raises exception if
189-
# the token is blacklisted. Probably other common code I could do
190-
# this to as well
191-
#
192-
# If setup to check every request, see if this token has been revoked
193-
if _blacklist_enabled() and _blacklist_checks() == 'all':
194-
store = _get_blacklist_store()
195-
jti = jwt_data['jti']
196-
token_status = store[jti]
197-
if token_status != 'active':
198-
return current_app.jwt_manager.blacklisted_token_callback()
220+
raise WrongTokenError('Only access tokens can access this endpoint')
221+
222+
# See if the token has been revoked (based on blacklist options)
223+
_check_blacklist(jwt_data)
199224

200225
# Save the jwt in the context so that it can be accessed later by
201226
# the various endpoints that is using this decorator
@@ -214,34 +239,22 @@ def fresh_jwt_required(fn):
214239
215240
:param fn: The view function to decorate
216241
"""
242+
@_handle_callbacks_on_error
217243
@wraps(fn)
218244
def wrapper(*args, **kwargs):
219-
try:
220-
secret = _get_secret_key()
221-
jwt_data = _verify_jwt_from_request(secret)
222-
except NoAuthHeaderError:
223-
return current_app.jwt_manager.unauthorized_callback()
224-
except jwt.ExpiredSignatureError as e:
225-
return current_app.jwt_manager.expired_token_callback()
226-
except (InvalidHeaderError, jwt.InvalidTokenError, JWTDecodeError) as e:
227-
return current_app.jwt_manager.invalid_token_callback(str(e))
245+
# Attempt to decode the token
246+
jwt_data = _decode_jwt_from_request()
228247

229248
# Verify this is an access token
230249
if jwt_data['type'] != 'access':
231-
err_msg = 'Only access tokens can access this endpoint'
232-
return current_app.jwt_manager.invalid_token_callback(err_msg)
250+
raise WrongTokenError('Only access tokens can access this endpoint')
233251

234-
# If setup to check every request, see if this token has been revoked
235-
if _blacklist_enabled() and _blacklist_checks() == 'all':
236-
store = _get_blacklist_store()
237-
jti = jwt_data['jti']
238-
token_status = store[jti]
239-
if token_status != 'active':
240-
return current_app.jwt_manager.blacklisted_token_callback()
252+
# See if the token has been revoked (based on blacklist options)
253+
_check_blacklist(jwt_data)
241254

242255
# Check if the token is fresh
243256
if not jwt_data['fresh']:
244-
return current_app.jwt_manager.token_needs_refresh_callback()
257+
raise FreshTokenRequired('Fresh token required')
245258

246259
# Save the jwt in the context so that it can be accessed later by
247260
# the various endpoints that is using this decorator
@@ -272,37 +285,23 @@ def authenticate(identity):
272285
return jsonify(ret), 200
273286

274287

288+
@_handle_callbacks_on_error
275289
def refresh():
276-
# Token options
277-
secret = _get_secret_key()
278-
config = current_app.config
279-
access_expire_delta = config.get('JWT_ACCESS_TOKEN_EXPIRES', ACCESS_EXPIRES)
280-
algorithm = config.get('JWT_ALGORITHM', ALGORITHM)
281-
282-
# Get the token
283-
try:
284-
jwt_data = _verify_jwt_from_request(secret)
285-
except NoAuthHeaderError:
286-
return current_app.jwt_manager.unauthorized_callback()
287-
except jwt.ExpiredSignatureError as e:
288-
return current_app.jwt_manager.expired_token_callback()
289-
except (InvalidHeaderError, jwt.InvalidTokenError, JWTDecodeError) as e:
290-
return current_app.jwt_manager.invalid_token_callback(str(e))
290+
# Get the JWT
291+
jwt_data = _decode_jwt_from_request()
291292

292293
# verify this is a refresh token
293294
if jwt_data['type'] != 'refresh':
294-
err_msg = 'Only refresh tokens can access this endpoint'
295-
return current_app.jwt_manager.invalid_token_callback(err_msg)
295+
raise WrongTokenError('Only refresh tokens can access this endpoint')
296296

297297
# If blacklisting is enabled, see if this token has been revoked
298-
if _blacklist_enabled():
299-
store = _get_blacklist_store()
300-
jti = jwt_data['jti']
301-
token_status = store[jti]
302-
if token_status != 'active':
303-
return current_app.jwt_manager.blacklisted_token_callback()
298+
_check_blacklist(jwt_data)
304299

305300
# Create and return the new access token
301+
config = current_app.config
302+
access_expire_delta = config.get('JWT_ACCESS_TOKEN_EXPIRES', ACCESS_EXPIRES)
303+
algorithm = config.get('JWT_ALGORITHM', ALGORITHM)
304+
secret = _get_secret_key()
306305
user_claims = current_app.jwt_manager.user_claims_callback(jwt_data['identity'])
307306
identity = jwt_data['identity']
308307
access_token = _encode_access_token(identity, secret, algorithm, access_expire_delta,
@@ -317,7 +316,6 @@ def fresh_authenticate(identity):
317316
config = current_app.config
318317
access_expire_delta = config.get('JWT_ACCESS_TOKEN_EXPIRES', ACCESS_EXPIRES)
319318
algorithm = config.get('JWT_ALGORITHM', ALGORITHM)
320-
321319
user_claims = current_app.jwt_manager.user_claims_callback(identity)
322320
access_token = _encode_access_token(identity, secret, algorithm, access_expire_delta,
323321
fresh=True, user_claims=user_claims)

0 commit comments

Comments
 (0)