Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion fasthtml/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@
'fasthtml.jupyter.wait_port_free': ('api/jupyter.html#wait_port_free', 'fasthtml/jupyter.py'),
'fasthtml.jupyter.ws_client': ('api/jupyter.html#ws_client', 'fasthtml/jupyter.py')},
'fasthtml.live_reload': {},
'fasthtml.oauth': { 'fasthtml.oauth.Auth0AppClient': ('api/oauth.html#auth0appclient', 'fasthtml/oauth.py'),
'fasthtml.oauth': { 'fasthtml.oauth.AppleAppClient': ('api/oauth.html#appleappclient', 'fasthtml/oauth.py'),
'fasthtml.oauth.AppleAppClient.__init__': ('api/oauth.html#appleappclient.__init__', 'fasthtml/oauth.py'),
'fasthtml.oauth.AppleAppClient.client_secret': ( 'api/oauth.html#appleappclient.client_secret',
'fasthtml/oauth.py'),
'fasthtml.oauth.AppleAppClient.get_info': ('api/oauth.html#appleappclient.get_info', 'fasthtml/oauth.py'),
'fasthtml.oauth.Auth0AppClient': ('api/oauth.html#auth0appclient', 'fasthtml/oauth.py'),
'fasthtml.oauth.Auth0AppClient.__init__': ('api/oauth.html#auth0appclient.__init__', 'fasthtml/oauth.py'),
'fasthtml.oauth.Auth0AppClient._fetch_openid_config': ( 'api/oauth.html#auth0appclient._fetch_openid_config',
'fasthtml/oauth.py'),
Expand Down
36 changes: 32 additions & 4 deletions fasthtml/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

# %% auto 0
__all__ = ['http_patterns', 'GoogleAppClient', 'GitHubAppClient', 'HuggingFaceClient', 'DiscordAppClient', 'Auth0AppClient',
'get_host', 'redir_url', 'url_match', 'OAuth', 'load_creds']
'AppleAppClient', 'get_host', 'redir_url', 'url_match', 'OAuth', 'load_creds']

# %% ../nbs/api/08_oauth.ipynb
from .common import *
from oauthlib.oauth2 import WebApplicationClient
from urllib.parse import urlparse, urlencode, parse_qs, quote, unquote
import secrets, httpx
import secrets, httpx, time

# %% ../nbs/api/08_oauth.ipynb
class _AppClient(WebApplicationClient):
Expand Down Expand Up @@ -112,6 +112,33 @@ def login_link(self, req):
d = dict(response_type="code", client_id=self.client_id, scope=self.scope, redirect_uri=redir_url(req, self.redirect_uri))
return f"{self.base_url}?{urlencode(d)}"

# %% ../nbs/api/08_oauth.ipynb
class AppleAppClient(_AppClient):
"A `WebApplicationClient` for Apple Sign In"
base_url = "https://appleid.apple.com/auth/authorize"
token_url = "https://appleid.apple.com/auth/token"

def __init__(self, client_id, key_id, team_id, private_key, code=None, scope=None, **kwargs):
if not scope: scope = ["name", "email"]
super().__init__(client_id, client_secret=None, code=code, scope=scope, **kwargs)
self.key_id, self.team_id, self.private_key = key_id, team_id, private_key

@property
def client_secret(self):
import jwt
now = int(time.time())
payload = dict(iss=self.team_id, iat=now, exp=now + 86400 * 180, aud='https://appleid.apple.com', sub=self.client_id)
return jwt.encode(payload, self.private_key, algorithm='ES256', headers={'kid': self.key_id})

@client_secret.setter
def client_secret(self, value): pass

def get_info(self, token=None):
"Decode user info from the ID token"
import jwt
if token: self.token = token
return jwt.decode(self.token.get('id_token'), options={"verify_signature": False})

# %% ../nbs/api/08_oauth.ipynb
@patch
def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None, **kwargs):
Expand Down Expand Up @@ -171,8 +198,9 @@ def url_match(request, patterns=http_patterns):

# %% ../nbs/api/08_oauth.ipynb
class OAuth:
def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns):
def __init__(self, app, cli, skip=None, redir_path='/redirect', error_path='/error', logout_path='/logout', login_path='/login', https=True, http_patterns=http_patterns, redir_method='get'):
if not skip: skip = [redir_path,error_path,login_path]
redir_handler = app.post if redir_method == 'post' else app.get
store_attr()
def before(req, session):
if 'auth' not in req.scope: req.scope['auth'] = session.get('auth')
Expand All @@ -182,7 +210,7 @@ def before(req, session):
if res: return res
app.before.append(Beforeware(before, skip=skip))

@app.get(redir_path)
@redir_handler(redir_path)
def redirect(req, session, code:str=None, error:str=None, state:str=None):
if not code:
session['oauth_error']=error
Expand Down
Loading