Skip to content

Commit 986c9c8

Browse files
committed
Add support for JWTs in headers and cookies in the same app
refs #26
1 parent cc6dc7d commit 986c9c8

File tree

8 files changed

+215
-58
lines changed

8 files changed

+215
-58
lines changed

docs/options.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ The available options are:
1313

1414
================================= =========================================
1515
``JWT_TOKEN_LOCATION`` Where to look for a JWT when processing a request. The options are ``'headers'`` or
16-
``'cookies'``. Defaults to ``'headers'``
16+
``'cookies'``. You can pass in a list to check more then one location: ```['headers', 'cookies']```.
17+
Defaults to ``'headers'``
1718
``JWT_HEADER_NAME`` What header to look for the JWT in a request. Only used if we are sending
1819
the JWT in via headers. Defaults to ``'Authorization'``
1920
``JWT_HEADER_TYPE`` What type of header the JWT is in. Defaults to ``'Bearer'``. This can be

examples/csrf_protection_with_cookies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
app.secret_key = 'super-secret' # Change this!
1010

1111
# Configure application to store JWTs in cookies
12-
app.config['JWT_TOKEN_LOCATION'] = 'cookies'
12+
app.config['JWT_TOKEN_LOCATION'] = ['cookies']
1313

1414
# Only allow JWT cookies to be sent over https. In production, this
1515
# should likely be True

examples/jwt_in_cookie.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Configure application to store JWTs in cookies. Whenever you make
1717
# a request to a protected endpoint, you will need to send in the
1818
# access or refresh JWT via a cookie.
19-
app.config['JWT_TOKEN_LOCATION'] = 'cookies'
19+
app.config['JWT_TOKEN_LOCATION'] = ['cookies']
2020

2121
# Set the cookie paths, so that you are only sending your access token
2222
# cookie to the access endpoints, and only sending your refresh token

flask_jwt_extended/config.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,8 @@
22
from flask import current_app
33

44

5-
# TODO support for cookies and headers at the same time. This could be useful
6-
# for using cookies in a web browser (more secure), and headers in a mobile
7-
# app (don't have to worry about csrf/xss there, and headers are easier to
8-
# manage in that environment)
9-
105
# Where to look for the JWT. Available options are cookies or headers
11-
TOKEN_LOCATION = 'headers'
6+
TOKEN_LOCATION = ['headers']
127

138
# Options for JWTs when the TOKEN_LOCATION is headers
149
HEADER_NAME = 'Authorization'
@@ -43,10 +38,14 @@
4338

4439

4540
def get_token_location():
46-
location = current_app.config.get('JWT_TOKEN_LOCATION', TOKEN_LOCATION)
47-
if location not in ['headers', 'cookies']:
48-
raise RuntimeError('JWT_LOCATION_LOCATION must be "headers" or "cookies"')
49-
return location
41+
locations = current_app.config.get('JWT_TOKEN_LOCATION', TOKEN_LOCATION)
42+
if not isinstance(locations, list):
43+
locations = [locations]
44+
for location in locations:
45+
if location not in ('headers', 'cookies'):
46+
raise RuntimeError('JWT_LOCATION_LOCATION can only contain '
47+
'"headers" and/or "cookies"')
48+
return locations
5049

5150

5251
def get_jwt_header_name():

flask_jwt_extended/utils.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _encode_access_token(identity, secret, algorithm, token_expire_delta,
8787
'type': 'access',
8888
'user_claims': user_claims,
8989
}
90-
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
90+
if 'cookies' in get_token_location() and get_cookie_csrf_protect() is True:
9191
token_data['csrf'] = _create_csrf_token()
9292
encoded_token = jwt.encode(token_data, secret, algorithm).decode('utf-8')
9393

@@ -119,7 +119,7 @@ def _encode_refresh_token(identity, secret, algorithm, token_expire_delta):
119119
'identity': identity,
120120
'type': 'refresh',
121121
}
122-
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
122+
if 'cookies' in get_token_location() and get_cookie_csrf_protect() is True:
123123
token_data['csrf'] = _create_csrf_token()
124124
encoded_token = jwt.encode(token_data, secret, algorithm).decode('utf-8')
125125

@@ -153,9 +153,6 @@ def _decode_jwt(token, secret, algorithm):
153153
raise JWTDecodeError("Missing or invalid claim: fresh")
154154
if 'user_claims' not in data or not isinstance(data['user_claims'], dict):
155155
raise JWTDecodeError("Missing or invalid claim: user_claims")
156-
if get_token_location() == 'cookies' and get_cookie_csrf_protect():
157-
if 'csrf' not in data or not isinstance(data['csrf'], six.string_types):
158-
raise JWTDecodeError("Missing or invalid claim: csrf")
159156
return data
160157

161158

@@ -200,17 +197,41 @@ def _decode_jwt_from_cookies(type):
200197

201198
if get_cookie_csrf_protect():
202199
csrf_header_key = get_csrf_header_name()
203-
csrf = request.headers.get(csrf_header_key, None)
204-
if not csrf or not safe_str_cmp(csrf, token['csrf']):
205-
raise NoAuthorizationError("Missing or invalid csrf double submit header")
206-
200+
csrf_token_from_header = request.headers.get(csrf_header_key, None)
201+
csrf_token_from_cookie = token.get('csrf', None)
202+
203+
# Verify the csrf tokens are present and matching
204+
if csrf_token_from_cookie is None:
205+
raise JWTDecodeError("Missing claim: 'csrf'")
206+
if not isinstance(csrf_token_from_cookie, six.string_types):
207+
raise JWTDecodeError("Invalid claim: 'csrf' (must be a string)")
208+
if csrf_token_from_header is None:
209+
raise NoAuthorizationError("Missing CSRF token in headers")
210+
if not safe_str_cmp(csrf_token_from_header, csrf_token_from_cookie):
211+
raise NoAuthorizationError("CSRF double submit tokens do not match")
207212
return token
208213

209214

210215
def _decode_jwt_from_request(type):
211-
token_location = get_token_location()
212-
if token_location == 'headers':
216+
token_locations = get_token_location()
217+
218+
# JWT can be in either headers or cookies
219+
if 'headers' in token_locations and 'cookies' in token_locations:
220+
try:
221+
return _decode_jwt_from_headers()
222+
except NoAuthorizationError:
223+
pass
224+
try:
225+
return _decode_jwt_from_cookies(type)
226+
except NoAuthorizationError:
227+
pass
228+
raise NoAuthorizationError("Missing JWT in header and cookies")
229+
230+
# JWT can only be in headers
231+
elif 'headers' in token_locations:
213232
return _decode_jwt_from_headers()
233+
234+
# JWT can only be in cookie
214235
else:
215236
return _decode_jwt_from_cookies(type)
216237

tests/test_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def setUp(self):
2525

2626
def test_default_configs(self):
2727
with self.app.test_request_context():
28-
self.assertEqual(get_token_location(), 'headers')
28+
self.assertEqual(get_token_location(), ['headers'])
2929
self.assertEqual(get_jwt_header_name(), 'Authorization')
3030
self.assertEqual(get_jwt_header_type(), 'Bearer')
3131

@@ -69,7 +69,7 @@ def test_override_configs(self):
6969
self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all'
7070

7171
with self.app.test_request_context():
72-
self.assertEqual(get_token_location(), 'cookies')
72+
self.assertEqual(get_token_location(), ['cookies'])
7373
self.assertEqual(get_jwt_header_name(), 'Auth')
7474
self.assertEqual(get_jwt_header_type(), 'JWT')
7575

tests/test_jwt_encode_decode.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -302,36 +302,6 @@ def test_decode_invalid_jwt(self):
302302
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
303303
_decode_jwt(encoded_token, 'secret', 'HS256')
304304

305-
# Missing and bad csrf tokens
306-
self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
307-
self.app.config['JWT_COOKIE_CSRF_PROTECTION'] = True
308-
with self.app.test_request_context():
309-
now = datetime.utcnow()
310-
with self.assertRaises(JWTDecodeError):
311-
token_data = {
312-
'exp': now + timedelta(minutes=5),
313-
'iat': now,
314-
'nbf': now,
315-
'jti': 'banana',
316-
'identity': 'banana',
317-
'type': 'refresh',
318-
}
319-
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
320-
_decode_jwt(encoded_token, 'secret', 'HS256')
321-
322-
with self.assertRaises(JWTDecodeError):
323-
token_data = {
324-
'exp': now + timedelta(minutes=5),
325-
'iat': now,
326-
'nbf': now,
327-
'jti': 'banana',
328-
'identity': 'banana',
329-
'type': 'refresh',
330-
'csrf': True
331-
}
332-
encoded_token = jwt.encode(token_data, 'secret', 'HS256').decode('utf-8')
333-
_decode_jwt(encoded_token, 'secret', 'HS256')
334-
335305
def test_create_access_token_with_object(self):
336306
# Complex object to test building a JWT from. Normally if you are using
337307
# this functionality, this is something that would be retrieved from

tests/test_protected_endpoints.py

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ def setUp(self):
358358
self.app.config['JWT_TOKEN_LOCATION'] = 'cookies'
359359
self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
360360
self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/auth/refresh'
361+
self.app.config['JWT_ACCESS_COOKIE_NAME'] = 'access_token_cookie'
362+
self.app.config['JWT_ALGORITHM'] = 'HS256'
361363
self.jwt_manager = JWTManager(self.app)
362364
self.client = self.app.test_client()
363365

@@ -549,9 +551,17 @@ def test_access_endpoints_with_cookies_and_csrf(self):
549551
self.assertEqual(status_code, 401)
550552
self.assertIn('msg', data)
551553

552-
# Try with logged in and bad double submit token
554+
# Try with logged in and bad header name for double submit token
555+
response = self.client.get('/api/protected',
556+
headers={'bad-header-name': 'banana'})
557+
status_code = response.status_code
558+
data = json.loads(response.get_data(as_text=True))
559+
self.assertEqual(status_code, 401)
560+
self.assertIn('msg', data)
561+
562+
# Try with logged in and bad header data for double submit token
553563
response = self.client.get('/api/protected',
554-
headers={'X-CSRF-ACCESS-TOKEN': 'banana'})
564+
headers={'X-CSRF-TOKEN': 'banana'})
555565
status_code = response.status_code
556566
data = json.loads(response.get_data(as_text=True))
557567
self.assertEqual(status_code, 401)
@@ -564,3 +574,159 @@ def test_access_endpoints_with_cookies_and_csrf(self):
564574
data = json.loads(response.get_data(as_text=True))
565575
self.assertEqual(status_code, 200)
566576
self.assertEqual(data, {'msg': 'hello world'})
577+
578+
def test_access_endpoints_with_cookie_missing_csrf_field(self):
579+
# Test accessing a csrf protected endpoint with a cookie that does not
580+
# have a csrf token in it
581+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
582+
self._login()
583+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
584+
585+
response = self.client.get('/api/protected')
586+
status_code = response.status_code
587+
data = json.loads(response.get_data(as_text=True))
588+
self.assertEqual(status_code, 422)
589+
self.assertIn('msg', data)
590+
591+
def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
592+
now = datetime.utcnow()
593+
token_data = {
594+
'exp': now + timedelta(minutes=5),
595+
'iat': now,
596+
'nbf': now,
597+
'jti': 'banana',
598+
'identity': 'banana',
599+
'type': 'refresh',
600+
'csrf': 404
601+
}
602+
secret = self.app.secret_key
603+
algorithm = self.app.config['JWT_ALGORITHM']
604+
encoded_token = jwt.encode(token_data, secret, algorithm).decode('utf-8')
605+
access_cookie_key = self.app.config['JWT_ACCESS_COOKIE_NAME']
606+
self.client.set_cookie('localhost', access_cookie_key, encoded_token)
607+
608+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
609+
response = self.client.get('/api/protected')
610+
status_code = response.status_code
611+
data = json.loads(response.get_data(as_text=True))
612+
self.assertEqual(status_code, 422)
613+
self.assertIn('msg', data)
614+
615+
616+
class TestEndpointsWithHeadersAndCookies(unittest.TestCase):
617+
618+
def setUp(self):
619+
self.app = Flask(__name__)
620+
self.app.secret_key = 'super=secret'
621+
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies', 'headers']
622+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
623+
self.app.config['JWT_ACCESS_COOKIE_PATH'] = '/api/'
624+
self.app.config['JWT_REFRESH_COOKIE_PATH'] = '/auth/refresh'
625+
self.jwt_manager = JWTManager(self.app)
626+
self.client = self.app.test_client()
627+
628+
@self.app.route('/auth/login_cookies', methods=['POST'])
629+
def login_cookies():
630+
# Create the tokens we will be sending back to the user
631+
access_token = create_access_token(identity='test')
632+
refresh_token = create_refresh_token(identity='test')
633+
634+
# Set the JWTs and the CSRF double submit protection cookies in this response
635+
resp = jsonify({'login': True})
636+
set_access_cookies(resp, access_token)
637+
set_refresh_cookies(resp, refresh_token)
638+
return resp, 200
639+
640+
@self.app.route('/auth/login_headers', methods=['POST'])
641+
def login_headers():
642+
ret = {
643+
'access_token': create_access_token('test', fresh=True),
644+
'refresh_token': create_refresh_token('test')
645+
}
646+
return jsonify(ret), 200
647+
648+
@self.app.route('/api/protected')
649+
@jwt_required
650+
def protected():
651+
return jsonify({'msg': "hello world"})
652+
653+
def _jwt_post(self, url, jwt):
654+
response = self.client.post(url, content_type='application/json',
655+
headers={'Authorization': 'Bearer {}'.format(jwt)})
656+
status_code = response.status_code
657+
data = json.loads(response.get_data(as_text=True))
658+
return status_code, data
659+
660+
def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
661+
header_type = '{} {}'.format(header_type, jwt).strip()
662+
response = self.client.get(url, headers={header_name: header_type})
663+
status_code = response.status_code
664+
data = json.loads(response.get_data(as_text=True))
665+
return status_code, data
666+
667+
def _login_cookies(self):
668+
resp = self.client.post('/auth/login_cookies')
669+
index = 1
670+
671+
access_cookie_str = resp.headers[index][1]
672+
access_cookie_key = access_cookie_str.split('=')[0]
673+
access_cookie_value = "".join(access_cookie_str.split('=')[1:])
674+
self.client.set_cookie('localhost', access_cookie_key, access_cookie_value)
675+
index += 1
676+
677+
if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
678+
access_csrf_str = resp.headers[index][1]
679+
access_csrf_key = access_csrf_str.split('=')[0]
680+
access_csrf_value = "".join(access_csrf_str.split('=')[1:])
681+
self.client.set_cookie('localhost', access_csrf_key, access_csrf_value)
682+
index += 1
683+
access_csrf = access_csrf_value.split(';')[0]
684+
else:
685+
access_csrf = ""
686+
687+
refresh_cookie_str = resp.headers[index][1]
688+
refresh_cookie_key = refresh_cookie_str.split('=')[0]
689+
refresh_cookie_value = "".join(refresh_cookie_str.split('=')[1:])
690+
self.client.set_cookie('localhost', refresh_cookie_key, refresh_cookie_value)
691+
index += 1
692+
693+
if self.app.config['JWT_COOKIE_CSRF_PROTECT']:
694+
refresh_csrf_str = resp.headers[index][1]
695+
refresh_csrf_key = refresh_csrf_str.split('=')[0]
696+
refresh_csrf_value = "".join(refresh_csrf_str.split('=')[1:])
697+
self.client.set_cookie('localhost', refresh_csrf_key, refresh_csrf_value)
698+
refresh_csrf = refresh_csrf_value.split(';')[0]
699+
else:
700+
refresh_csrf = ""
701+
702+
return access_csrf, refresh_csrf
703+
704+
def _login_headers(self):
705+
resp = self.client.post('/auth/login_headers')
706+
data = json.loads(resp.get_data(as_text=True))
707+
return data['access_token'], data['refresh_token']
708+
709+
def test_accessing_endpoint_with_headers(self):
710+
access_token, _ = self._login_headers()
711+
header_type = '{} {}'.format('Bearer', access_token).strip()
712+
response = self.client.get('/api/protected', headers={'Authorization': header_type})
713+
status_code = response.status_code
714+
data = json.loads(response.get_data(as_text=True))
715+
self.assertEqual(status_code, 200)
716+
self.assertEqual(data, {'msg': 'hello world'})
717+
718+
def test_accessing_endpoint_with_cookies(self):
719+
access_csrf, _ = self._login_cookies()
720+
response = self.client.get('/api/protected',
721+
headers={'X-CSRF-TOKEN': access_csrf})
722+
status_code = response.status_code
723+
data = json.loads(response.get_data(as_text=True))
724+
self.assertEqual(status_code, 200)
725+
self.assertEqual(data, {'msg': 'hello world'})
726+
727+
def test_accessing_endpoint_without_jwt(self):
728+
response = self.client.get('/api/protected')
729+
status_code = response.status_code
730+
data = json.loads(response.get_data(as_text=True))
731+
self.assertEqual(status_code, 401)
732+
self.assertIn('msg', data)

0 commit comments

Comments
 (0)