1515from flask_jwt_extended .config import ALGORITHM , REFRESH_EXPIRES , ACCESS_EXPIRES , \
1616 BLACKLIST_ENABLED , BLACKLIST_STORE , BLACKLIST_TOKEN_CHECKS
1717from flask_jwt_extended .exceptions import JWTEncodeError , JWTDecodeError , \
18- InvalidHeaderError , NoAuthHeaderError
19-
18+ InvalidHeaderError , NoAuthHeaderError , WrongTokenError , RevokedTokenError , \
19+ FreshTokenRequired
2020
2121# Proxy for accessing the identity of the JWT in this context
2222jwt_identity = LocalProxy (lambda : _get_identity ())
@@ -133,12 +133,7 @@ def _decode_jwt(token, secret, algorithm):
133133 return data
134134
135135
136- def _verify_jwt_from_request (secret ):
137- """
138- Returns the encoded JWT string from the request
139-
140- :return: Encoded jwt string, or None if it does not exist
141- """
136+ def _decode_jwt_from_request ():
142137 # Verify we have the auth header
143138 auth_header = request .headers .get ('Authorization' , None )
144139 if not auth_header :
@@ -154,9 +149,56 @@ def _verify_jwt_from_request(secret):
154149 raise InvalidHeaderError (msg )
155150
156151 token = parts [1 ]
152+ secret = _get_secret_key ()
157153 return _decode_jwt (token , secret , 'HS256' )
158154
159155
156+ def _handle_callbacks_on_error (fn ):
157+ """
158+ Helper decorator that will catch any exceptions we expect to encounter
159+ when dealing with a JWT, and call the appropriate callback function for
160+ handling that error. Callback functions can be set in using the *_loader
161+ methods in jwt_manager.
162+ """
163+ @wraps (fn )
164+ def wrapper (* args , ** kwargs ):
165+ m = current_app .jwt_manager
166+
167+ try :
168+ return fn (* args , ** kwargs )
169+ except NoAuthHeaderError :
170+ return m .unauthorized_callback ()
171+ except jwt .ExpiredSignatureError :
172+ return m .expired_token_callback ()
173+ except (InvalidHeaderError , jwt .InvalidTokenError , JWTDecodeError ,
174+ WrongTokenError ) as e :
175+ return m .invalid_token_callback (str (e ))
176+ except RevokedTokenError :
177+ return m .blacklisted_token_callback ()
178+ except FreshTokenRequired :
179+ return m .token_needs_refresh_callback ()
180+ return wrapper
181+
182+
183+ def _check_blacklist (jwt_data ):
184+ if not _blacklist_enabled ():
185+ return
186+
187+ store = _get_blacklist_store ()
188+ token_type = jwt_data ['type' ]
189+ jti = jwt_data ['jti' ]
190+
191+ if token_type == 'access' and _blacklist_checks () == 'all' :
192+ token_status = store [jti ]
193+ if token_status != 'active' :
194+ raise RevokedTokenError ('{} has been revoked' .format )
195+
196+ if token_type == 'refresh' and _blacklist_checks () in ('all' , 'refresh' ):
197+ token_status = store [jti ]
198+ if token_status != 'active' :
199+ raise RevokedTokenError ('{} has been revoked' .format )
200+
201+
160202def jwt_required (fn ):
161203 """
162204 If you decorate a vew with this, it will ensure that the requester has a valid
@@ -167,35 +209,18 @@ def jwt_required(fn):
167209
168210 :param fn: The view function to decorate
169211 """
212+ @_handle_callbacks_on_error
170213 @wraps (fn )
171214 def wrapper (* args , ** kwargs ):
172215 # Attempt to decode the token
173- try :
174- secret = _get_secret_key ()
175- jwt_data = _verify_jwt_from_request (secret )
176- except NoAuthHeaderError :
177- return current_app .jwt_manager .unauthorized_callback ()
178- except jwt .ExpiredSignatureError as e :
179- return current_app .jwt_manager .expired_token_callback ()
180- except (InvalidHeaderError , jwt .InvalidTokenError , JWTDecodeError ) as e :
181- return current_app .jwt_manager .invalid_token_callback (str (e ))
216+ jwt_data = _decode_jwt_from_request ()
182217
183218 # Verify this is an access token
184219 if jwt_data ['type' ] != 'access' :
185- err_msg = 'Only access tokens can access this endpoint'
186- return current_app .jwt_manager .invalid_token_callback (err_msg )
187-
188- # TODO move this into common helper function that raises exception if
189- # the token is blacklisted. Probably other common code I could do
190- # this to as well
191- #
192- # If setup to check every request, see if this token has been revoked
193- if _blacklist_enabled () and _blacklist_checks () == 'all' :
194- store = _get_blacklist_store ()
195- jti = jwt_data ['jti' ]
196- token_status = store [jti ]
197- if token_status != 'active' :
198- return current_app .jwt_manager .blacklisted_token_callback ()
220+ raise WrongTokenError ('Only access tokens can access this endpoint' )
221+
222+ # See if the token has been revoked (based on blacklist options)
223+ _check_blacklist (jwt_data )
199224
200225 # Save the jwt in the context so that it can be accessed later by
201226 # the various endpoints that is using this decorator
@@ -214,34 +239,22 @@ def fresh_jwt_required(fn):
214239
215240 :param fn: The view function to decorate
216241 """
242+ @_handle_callbacks_on_error
217243 @wraps (fn )
218244 def wrapper (* args , ** kwargs ):
219- try :
220- secret = _get_secret_key ()
221- jwt_data = _verify_jwt_from_request (secret )
222- except NoAuthHeaderError :
223- return current_app .jwt_manager .unauthorized_callback ()
224- except jwt .ExpiredSignatureError as e :
225- return current_app .jwt_manager .expired_token_callback ()
226- except (InvalidHeaderError , jwt .InvalidTokenError , JWTDecodeError ) as e :
227- return current_app .jwt_manager .invalid_token_callback (str (e ))
245+ # Attempt to decode the token
246+ jwt_data = _decode_jwt_from_request ()
228247
229248 # Verify this is an access token
230249 if jwt_data ['type' ] != 'access' :
231- err_msg = 'Only access tokens can access this endpoint'
232- return current_app .jwt_manager .invalid_token_callback (err_msg )
250+ raise WrongTokenError ('Only access tokens can access this endpoint' )
233251
234- # If setup to check every request, see if this token has been revoked
235- if _blacklist_enabled () and _blacklist_checks () == 'all' :
236- store = _get_blacklist_store ()
237- jti = jwt_data ['jti' ]
238- token_status = store [jti ]
239- if token_status != 'active' :
240- return current_app .jwt_manager .blacklisted_token_callback ()
252+ # See if the token has been revoked (based on blacklist options)
253+ _check_blacklist (jwt_data )
241254
242255 # Check if the token is fresh
243256 if not jwt_data ['fresh' ]:
244- return current_app . jwt_manager . token_needs_refresh_callback ( )
257+ raise FreshTokenRequired ( 'Fresh token required' )
245258
246259 # Save the jwt in the context so that it can be accessed later by
247260 # the various endpoints that is using this decorator
@@ -272,37 +285,23 @@ def authenticate(identity):
272285 return jsonify (ret ), 200
273286
274287
288+ @_handle_callbacks_on_error
275289def refresh ():
276- # Token options
277- secret = _get_secret_key ()
278- config = current_app .config
279- access_expire_delta = config .get ('JWT_ACCESS_TOKEN_EXPIRES' , ACCESS_EXPIRES )
280- algorithm = config .get ('JWT_ALGORITHM' , ALGORITHM )
281-
282- # Get the token
283- try :
284- jwt_data = _verify_jwt_from_request (secret )
285- except NoAuthHeaderError :
286- return current_app .jwt_manager .unauthorized_callback ()
287- except jwt .ExpiredSignatureError as e :
288- return current_app .jwt_manager .expired_token_callback ()
289- except (InvalidHeaderError , jwt .InvalidTokenError , JWTDecodeError ) as e :
290- return current_app .jwt_manager .invalid_token_callback (str (e ))
290+ # Get the JWT
291+ jwt_data = _decode_jwt_from_request ()
291292
292293 # verify this is a refresh token
293294 if jwt_data ['type' ] != 'refresh' :
294- err_msg = 'Only refresh tokens can access this endpoint'
295- return current_app .jwt_manager .invalid_token_callback (err_msg )
295+ raise WrongTokenError ('Only refresh tokens can access this endpoint' )
296296
297297 # If blacklisting is enabled, see if this token has been revoked
298- if _blacklist_enabled ():
299- store = _get_blacklist_store ()
300- jti = jwt_data ['jti' ]
301- token_status = store [jti ]
302- if token_status != 'active' :
303- return current_app .jwt_manager .blacklisted_token_callback ()
298+ _check_blacklist (jwt_data )
304299
305300 # Create and return the new access token
301+ config = current_app .config
302+ access_expire_delta = config .get ('JWT_ACCESS_TOKEN_EXPIRES' , ACCESS_EXPIRES )
303+ algorithm = config .get ('JWT_ALGORITHM' , ALGORITHM )
304+ secret = _get_secret_key ()
306305 user_claims = current_app .jwt_manager .user_claims_callback (jwt_data ['identity' ])
307306 identity = jwt_data ['identity' ]
308307 access_token = _encode_access_token (identity , secret , algorithm , access_expire_delta ,
@@ -317,7 +316,6 @@ def fresh_authenticate(identity):
317316 config = current_app .config
318317 access_expire_delta = config .get ('JWT_ACCESS_TOKEN_EXPIRES' , ACCESS_EXPIRES )
319318 algorithm = config .get ('JWT_ALGORITHM' , ALGORITHM )
320-
321319 user_claims = current_app .jwt_manager .user_claims_callback (identity )
322320 access_token = _encode_access_token (identity , secret , algorithm , access_expire_delta ,
323321 fresh = True , user_claims = user_claims )
0 commit comments