@@ -358,6 +358,8 @@ def setUp(self):
358358 self .app .config ['JWT_TOKEN_LOCATION' ] = 'cookies'
359359 self .app .config ['JWT_ACCESS_COOKIE_PATH' ] = '/api/'
360360 self .app .config ['JWT_REFRESH_COOKIE_PATH' ] = '/auth/refresh'
361+ self .app .config ['JWT_ACCESS_COOKIE_NAME' ] = 'access_token_cookie'
362+ self .app .config ['JWT_ALGORITHM' ] = 'HS256'
361363 self .jwt_manager = JWTManager (self .app )
362364 self .client = self .app .test_client ()
363365
@@ -549,9 +551,17 @@ def test_access_endpoints_with_cookies_and_csrf(self):
549551 self .assertEqual (status_code , 401 )
550552 self .assertIn ('msg' , data )
551553
552- # Try with logged in and bad double submit token
554+ # Try with logged in and bad header name for double submit token
555+ response = self .client .get ('/api/protected' ,
556+ headers = {'bad-header-name' : 'banana' })
557+ status_code = response .status_code
558+ data = json .loads (response .get_data (as_text = True ))
559+ self .assertEqual (status_code , 401 )
560+ self .assertIn ('msg' , data )
561+
562+ # Try with logged in and bad header data for double submit token
553563 response = self .client .get ('/api/protected' ,
554- headers = {'X-CSRF-ACCESS- TOKEN' : 'banana' })
564+ headers = {'X-CSRF-TOKEN' : 'banana' })
555565 status_code = response .status_code
556566 data = json .loads (response .get_data (as_text = True ))
557567 self .assertEqual (status_code , 401 )
@@ -564,3 +574,159 @@ def test_access_endpoints_with_cookies_and_csrf(self):
564574 data = json .loads (response .get_data (as_text = True ))
565575 self .assertEqual (status_code , 200 )
566576 self .assertEqual (data , {'msg' : 'hello world' })
577+
578+ def test_access_endpoints_with_cookie_missing_csrf_field (self ):
579+ # Test accessing a csrf protected endpoint with a cookie that does not
580+ # have a csrf token in it
581+ self .app .config ['JWT_COOKIE_CSRF_PROTECT' ] = False
582+ self ._login ()
583+ self .app .config ['JWT_COOKIE_CSRF_PROTECT' ] = True
584+
585+ response = self .client .get ('/api/protected' )
586+ status_code = response .status_code
587+ data = json .loads (response .get_data (as_text = True ))
588+ self .assertEqual (status_code , 422 )
589+ self .assertIn ('msg' , data )
590+
591+ def test_access_endpoints_with_cookie_csrf_claim_not_string (self ):
592+ now = datetime .utcnow ()
593+ token_data = {
594+ 'exp' : now + timedelta (minutes = 5 ),
595+ 'iat' : now ,
596+ 'nbf' : now ,
597+ 'jti' : 'banana' ,
598+ 'identity' : 'banana' ,
599+ 'type' : 'refresh' ,
600+ 'csrf' : 404
601+ }
602+ secret = self .app .secret_key
603+ algorithm = self .app .config ['JWT_ALGORITHM' ]
604+ encoded_token = jwt .encode (token_data , secret , algorithm ).decode ('utf-8' )
605+ access_cookie_key = self .app .config ['JWT_ACCESS_COOKIE_NAME' ]
606+ self .client .set_cookie ('localhost' , access_cookie_key , encoded_token )
607+
608+ self .app .config ['JWT_COOKIE_CSRF_PROTECT' ] = True
609+ response = self .client .get ('/api/protected' )
610+ status_code = response .status_code
611+ data = json .loads (response .get_data (as_text = True ))
612+ self .assertEqual (status_code , 422 )
613+ self .assertIn ('msg' , data )
614+
615+
616+ class TestEndpointsWithHeadersAndCookies (unittest .TestCase ):
617+
618+ def setUp (self ):
619+ self .app = Flask (__name__ )
620+ self .app .secret_key = 'super=secret'
621+ self .app .config ['JWT_TOKEN_LOCATION' ] = ['cookies' , 'headers' ]
622+ self .app .config ['JWT_COOKIE_CSRF_PROTECT' ] = True
623+ self .app .config ['JWT_ACCESS_COOKIE_PATH' ] = '/api/'
624+ self .app .config ['JWT_REFRESH_COOKIE_PATH' ] = '/auth/refresh'
625+ self .jwt_manager = JWTManager (self .app )
626+ self .client = self .app .test_client ()
627+
628+ @self .app .route ('/auth/login_cookies' , methods = ['POST' ])
629+ def login_cookies ():
630+ # Create the tokens we will be sending back to the user
631+ access_token = create_access_token (identity = 'test' )
632+ refresh_token = create_refresh_token (identity = 'test' )
633+
634+ # Set the JWTs and the CSRF double submit protection cookies in this response
635+ resp = jsonify ({'login' : True })
636+ set_access_cookies (resp , access_token )
637+ set_refresh_cookies (resp , refresh_token )
638+ return resp , 200
639+
640+ @self .app .route ('/auth/login_headers' , methods = ['POST' ])
641+ def login_headers ():
642+ ret = {
643+ 'access_token' : create_access_token ('test' , fresh = True ),
644+ 'refresh_token' : create_refresh_token ('test' )
645+ }
646+ return jsonify (ret ), 200
647+
648+ @self .app .route ('/api/protected' )
649+ @jwt_required
650+ def protected ():
651+ return jsonify ({'msg' : "hello world" })
652+
653+ def _jwt_post (self , url , jwt ):
654+ response = self .client .post (url , content_type = 'application/json' ,
655+ headers = {'Authorization' : 'Bearer {}' .format (jwt )})
656+ status_code = response .status_code
657+ data = json .loads (response .get_data (as_text = True ))
658+ return status_code , data
659+
660+ def _jwt_get (self , url , jwt , header_name = 'Authorization' , header_type = 'Bearer' ):
661+ header_type = '{} {}' .format (header_type , jwt ).strip ()
662+ response = self .client .get (url , headers = {header_name : header_type })
663+ status_code = response .status_code
664+ data = json .loads (response .get_data (as_text = True ))
665+ return status_code , data
666+
667+ def _login_cookies (self ):
668+ resp = self .client .post ('/auth/login_cookies' )
669+ index = 1
670+
671+ access_cookie_str = resp .headers [index ][1 ]
672+ access_cookie_key = access_cookie_str .split ('=' )[0 ]
673+ access_cookie_value = "" .join (access_cookie_str .split ('=' )[1 :])
674+ self .client .set_cookie ('localhost' , access_cookie_key , access_cookie_value )
675+ index += 1
676+
677+ if self .app .config ['JWT_COOKIE_CSRF_PROTECT' ]:
678+ access_csrf_str = resp .headers [index ][1 ]
679+ access_csrf_key = access_csrf_str .split ('=' )[0 ]
680+ access_csrf_value = "" .join (access_csrf_str .split ('=' )[1 :])
681+ self .client .set_cookie ('localhost' , access_csrf_key , access_csrf_value )
682+ index += 1
683+ access_csrf = access_csrf_value .split (';' )[0 ]
684+ else :
685+ access_csrf = ""
686+
687+ refresh_cookie_str = resp .headers [index ][1 ]
688+ refresh_cookie_key = refresh_cookie_str .split ('=' )[0 ]
689+ refresh_cookie_value = "" .join (refresh_cookie_str .split ('=' )[1 :])
690+ self .client .set_cookie ('localhost' , refresh_cookie_key , refresh_cookie_value )
691+ index += 1
692+
693+ if self .app .config ['JWT_COOKIE_CSRF_PROTECT' ]:
694+ refresh_csrf_str = resp .headers [index ][1 ]
695+ refresh_csrf_key = refresh_csrf_str .split ('=' )[0 ]
696+ refresh_csrf_value = "" .join (refresh_csrf_str .split ('=' )[1 :])
697+ self .client .set_cookie ('localhost' , refresh_csrf_key , refresh_csrf_value )
698+ refresh_csrf = refresh_csrf_value .split (';' )[0 ]
699+ else :
700+ refresh_csrf = ""
701+
702+ return access_csrf , refresh_csrf
703+
704+ def _login_headers (self ):
705+ resp = self .client .post ('/auth/login_headers' )
706+ data = json .loads (resp .get_data (as_text = True ))
707+ return data ['access_token' ], data ['refresh_token' ]
708+
709+ def test_accessing_endpoint_with_headers (self ):
710+ access_token , _ = self ._login_headers ()
711+ header_type = '{} {}' .format ('Bearer' , access_token ).strip ()
712+ response = self .client .get ('/api/protected' , headers = {'Authorization' : header_type })
713+ status_code = response .status_code
714+ data = json .loads (response .get_data (as_text = True ))
715+ self .assertEqual (status_code , 200 )
716+ self .assertEqual (data , {'msg' : 'hello world' })
717+
718+ def test_accessing_endpoint_with_cookies (self ):
719+ access_csrf , _ = self ._login_cookies ()
720+ response = self .client .get ('/api/protected' ,
721+ headers = {'X-CSRF-TOKEN' : access_csrf })
722+ status_code = response .status_code
723+ data = json .loads (response .get_data (as_text = True ))
724+ self .assertEqual (status_code , 200 )
725+ self .assertEqual (data , {'msg' : 'hello world' })
726+
727+ def test_accessing_endpoint_without_jwt (self ):
728+ response = self .client .get ('/api/protected' )
729+ status_code = response .status_code
730+ data = json .loads (response .get_data (as_text = True ))
731+ self .assertEqual (status_code , 401 )
732+ self .assertIn ('msg' , data )
0 commit comments