55from api .resolvers import Mutation , Query , Subscription
66from auth import (
77 UnauthenticatedRedirect ,
8- authenticate_connection ,
9- get_token_source ,
8+ get_user_payload ,
109 unauthenticated_redirect_handler ,
1110)
1211from config import (
1918from database .models .user import User
2019from database .session import get_db_session
2120from fastapi import Depends , FastAPI , HTTPException , Request
22- from fastapi .responses import RedirectResponse
21+ from fastapi .responses import HTMLResponse , RedirectResponse
22+ from graphql import FieldNode
2323from sqlalchemy import select
2424from starlette .requests import HTTPConnection
25- from starlette .status import HTTP_400_BAD_REQUEST
2625from strawberry import Schema
26+ from strawberry .extensions import SchemaExtension
2727from strawberry .fastapi import GraphQLRouter
2828
2929logger = logging .getLogger (LOGGER )
3030
3131
32+ class GlobalAuthExtension (SchemaExtension ):
33+ def on_execute (self ):
34+ execution_context = self .execution_context
35+ user = execution_context .context .user
36+
37+ if user :
38+ yield
39+ return
40+
41+ document = execution_context .graphql_document
42+ if document :
43+ for definition in document .definitions :
44+ if definition .kind == "operation_definition" :
45+ for selection in definition .selection_set .selections :
46+ if not isinstance (selection , FieldNode ):
47+ raise HTTPException (
48+ status_code = 401 ,
49+ detail = "Not authenticated" ,
50+ )
51+ if not selection .name .value .startswith ("__" ):
52+ raise HTTPException (
53+ status_code = 401 ,
54+ detail = "Not authenticated" ,
55+ )
56+
57+ yield
58+
59+
3260async def get_context (
3361 connection : HTTPConnection ,
34- token : str | None = Depends (get_token_source ),
3562 session = Depends (get_db_session ),
3663) -> Context :
37- user_payload = await authenticate_connection (connection , token )
64+ user_payload = get_user_payload (connection )
3865
39- user_id = user_payload .get ("sub" )
40- username = user_payload .get ("preferred_username" ) or user_payload .get (
41- "name" ,
42- )
43- firstname = user_payload .get ("given_name" )
44- lastname = user_payload .get ("family_name" )
66+ db_user = None
4567
46- if not ( user_id and username and firstname and lastname ) :
47- raise HTTPException (
48- status_code = HTTP_400_BAD_REQUEST ,
49- detail = "Missing required user details. " ,
68+ if user_payload :
69+ user_id = user_payload . get ( "sub" )
70+ username = user_payload . get ( "preferred_username" ) or user_payload . get (
71+ "name " ,
5072 )
73+ firstname = user_payload .get ("given_name" )
74+ lastname = user_payload .get ("family_name" )
5175
52- result = await session .execute (select (User ).where (User .id == user_id ))
53- db_user = result .scalars ().first ()
76+ if user_id :
77+ result = await session .execute (
78+ select (User ).where (User .id == user_id ),
79+ )
80+ db_user = result .scalars ().first ()
5481
55- if not db_user :
56- db_user = User (
57- id = user_id ,
58- name = username ,
59- firstname = firstname ,
60- lastname = lastname ,
61- title = "User" ,
62- )
63- session .add (db_user )
64- elif db_user .name != username or db_user .firstname != firstname :
65- db_user .name = username
66- db_user .firstname = firstname
67- db_user .lastname = lastname
82+ if not db_user :
83+ db_user = User (
84+ id = user_id ,
85+ name = username ,
86+ firstname = firstname ,
87+ lastname = lastname ,
88+ title = "User" ,
89+ )
90+ session .add (db_user )
91+ elif (
92+ db_user .name != username
93+ or db_user .firstname != firstname
94+ or db_user .lastname != lastname
95+ ):
96+ db_user .name = username
97+ db_user .firstname = firstname
98+ db_user .lastname = lastname
6899
69- await session .commit ()
70- await session .refresh (db_user )
100+ await session .commit ()
101+ await session .refresh (db_user )
71102
72103 return Context (db = session , user = db_user )
73104
@@ -76,8 +107,107 @@ async def get_context(
76107 query = Query ,
77108 mutation = Mutation ,
78109 subscription = Subscription ,
110+ extensions = [GlobalAuthExtension ],
79111)
80- graphql_app = GraphQLRouter (
112+
113+
114+ class AuthedGraphQLRouter (GraphQLRouter ):
115+ async def render_graphql_ide (self , request : Request ) -> HTMLResponse :
116+ response = await super ().render_graphql_ide (request )
117+
118+ redirect_uri = f"{ request .base_url } callback"
119+ login_url = (
120+ f"{ ISSUER_URI } /protocol/openid-connect/auth"
121+ f"?client_id={ CLIENT_ID } "
122+ f"&response_type=code"
123+ f"&scope=openid profile email"
124+ f"&redirect_uri={ redirect_uri } "
125+ )
126+ logout_url = f"{ request .base_url } logout"
127+
128+ is_authenticated = (
129+ "true" if request .cookies .get ("access_token" ) else "false"
130+ )
131+
132+ injection_script = f"""
133+ <script>
134+ (function() {{
135+ var loginUrl = "{ login_url } ";
136+ var logoutUrl = "{ logout_url } ";
137+ var isAuthenticated = { is_authenticated } ;
138+
139+ function injectAuthButton() {{
140+ var sidebars = document.querySelectorAll('.graphiql-sidebar-section');
141+ var sidebar = sidebars[0];
142+
143+ if (sidebar && !document.getElementById('custom-auth-button')) {{
144+ var button = document.createElement('button');
145+ button.id = 'custom-auth-button';
146+ button.className = 'graphiql-un-styled';
147+ button.type = 'button';
148+
149+ if (isAuthenticated) {{
150+ button.setAttribute('aria-label', 'Logout');
151+ button.title = 'Logout';
152+ button.innerHTML = `
153+ <svg height="1em" viewBox="0 0 24 24" fill="none"
154+ stroke="currentColor" stroke-width="1.5"
155+ stroke-linecap="round" stroke-linejoin="round"
156+ xmlns="http://www.w3.org/2000/svg">
157+ <path d="M9 21H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h4"></path>
158+ <polyline points="16 17 21 12 16 7"></polyline>
159+ <line x1="21" y1="12" x2="9" y2="12"></line>
160+ </svg>
161+ `;
162+ button.onclick = function() {{
163+ window.location.href = logoutUrl;
164+ }};
165+ }} else {{
166+ button.setAttribute('aria-label', 'Login with OIDC');
167+ button.title = 'Login with OIDC';
168+ button.innerHTML = `
169+ <svg height="1em" viewBox="0 0 24 24" fill="none"
170+ stroke="currentColor" stroke-width="1.5"
171+ stroke-linecap="round" stroke-linejoin="round"
172+ xmlns="http://www.w3.org/2000/svg">
173+ <path d="M15 3h4a2 2 0 0 1 2 2v14a2 2 0 0 1-2 2h-4"></path>
174+ <polyline points="10 17 15 12 10 7"></polyline>
175+ <line x1="15" y1="12" x2="3" y2="12"></line>
176+ </svg>
177+ `;
178+ button.onclick = function() {{
179+ window.location.href = loginUrl;
180+ }};
181+ }}
182+
183+ sidebar.appendChild(button);
184+ }}
185+ }}
186+
187+ var observer = new MutationObserver(function(mutations) {{
188+ injectAuthButton();
189+ }});
190+
191+ observer.observe(document.body, {{
192+ childList: true,
193+ subtree: true
194+ }});
195+
196+ injectAuthButton();
197+ }})();
198+ </script>
199+ """
200+
201+ html_content = response .body .decode ("utf-8" )
202+ new_html = html_content .replace (
203+ "</body>" ,
204+ f"{ injection_script } </body>" ,
205+ )
206+
207+ return HTMLResponse (new_html )
208+
209+
210+ graphql_app = AuthedGraphQLRouter (
81211 schema ,
82212 context_getter = get_context ,
83213 graphql_ide = IS_DEV ,
@@ -97,6 +227,19 @@ async def get_context(
97227)
98228
99229
230+ @app .get ("/logout" )
231+ def logout (_ : Request ):
232+ response = RedirectResponse (url = "/graphql" )
233+ response .delete_cookie (
234+ "access_token" ,
235+ path = "/" ,
236+ secure = True ,
237+ httponly = True ,
238+ samesite = "lax" ,
239+ )
240+ return response
241+
242+
100243@app .get ("/callback" )
101244def oauth_callback (code : str , request : Request ):
102245 token_endpoint = f"{ ISSUER_URI } /protocol/openid-connect/token"
@@ -123,6 +266,8 @@ def oauth_callback(code: str, request: Request):
123266 access_token ,
124267 httponly = True ,
125268 max_age = 3600 ,
269+ samesite = "lax" ,
270+ secure = True ,
126271 )
127272 return resp
128273
0 commit comments