Skip to content

Commit 16fbf20

Browse files
Shield: Add sign out implementation
1 parent 8a59e76 commit 16fbf20

File tree

11 files changed

+131
-74
lines changed

11 files changed

+131
-74
lines changed

examples/leptos-actix/src/home.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView {
2020
{move || Suspend::new(async move { match user.await {
2121
Ok(user) => Either::Left(match user {
2222
Some(user) => Either::Left(view! {
23-
{user.id}
23+
<p><b>User ID:</b> {user.id}</p>
2424

2525
<A href="/auth/sign-out">
2626
<button>"Sign out"</button>

examples/leptos-axum/src/home.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ pub fn HomePage() -> impl IntoView {
2020
{move || Suspend::new(async move { match user.await {
2121
Ok(user) => Either::Left(match user {
2222
Some(user) => Either::Left(view! {
23-
{user.id}
23+
<p><b>User ID:</b> {user.id}</p>
2424

2525
<A href="/auth/sign-out">
2626
<button>"Sign out"</button>

packages/core/shield/src/session.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,18 @@ impl Session {
4343

4444
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
4545
pub struct SessionData {
46-
pub user_id: Option<String>,
46+
pub authentication: Option<Authentication>,
4747

48-
// TODO: allow arbitrary data to be stored by providers?
48+
// TODO: Allow arbitrary data to be stored by providers?
4949
pub csrf: Option<String>,
5050
pub nonce: Option<String>,
5151
pub verifier: Option<String>,
52+
pub oidc_connection_id: Option<String>,
53+
}
54+
55+
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
56+
pub struct Authentication {
57+
pub provider_id: String,
58+
pub subprovider_id: Option<String>,
59+
pub user_id: String,
5260
}

packages/core/shield/src/shield.rs

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use futures::future::try_join_all;
44
use tracing::debug;
55

66
use crate::{
7-
error::{ProviderError, ShieldError},
7+
error::{ProviderError, SessionError, ShieldError},
88
provider::{Provider, Subprovider, SubproviderVisualisation},
99
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
1010
response::Response,
@@ -107,19 +107,39 @@ impl<U: User> Shield<U> {
107107
provider.sign_in_callback(request, session).await
108108
}
109109

110-
pub async fn sign_out(
111-
&self,
112-
request: SignOutRequest,
113-
session: Session,
114-
) -> Result<Response, ShieldError> {
115-
debug!("sign out {:?}", request);
110+
pub async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
111+
debug!("sign out");
116112

117-
let provider = match self.providers.get(&request.provider_id) {
118-
Some(provider) => provider,
119-
None => return Err(ProviderError::ProviderNotFound(request.provider_id).into()),
113+
let authenticated = {
114+
let session_data = session.data();
115+
let session_data = session_data
116+
.lock()
117+
.map_err(|err| SessionError::Lock(err.to_string()))?;
118+
119+
session_data.authentication.clone()
120120
};
121121

122-
let response = provider.sign_out(request, session.clone()).await?;
122+
let response = if let Some(authenticated) = authenticated {
123+
let provider = match self.providers.get(&authenticated.provider_id) {
124+
Some(provider) => provider,
125+
None => {
126+
return Err(ProviderError::ProviderNotFound(authenticated.provider_id).into())
127+
}
128+
};
129+
130+
provider
131+
.sign_out(
132+
SignOutRequest {
133+
provider_id: authenticated.provider_id,
134+
subprovider_id: authenticated.subprovider_id,
135+
},
136+
session.clone(),
137+
)
138+
.await?
139+
} else {
140+
// TODO: Should be configurable.
141+
Response::Redirect("/".to_owned())
142+
};
123143

124144
session.purge().await?;
125145

packages/core/shield/src/shield_dyn.rs

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use async_trait::async_trait;
55
use crate::{
66
error::ShieldError,
77
provider::{Subprovider, SubproviderVisualisation},
8-
request::{SignInCallbackRequest, SignInRequest, SignOutRequest},
8+
request::{SignInCallbackRequest, SignInRequest},
99
response::Response,
1010
session::Session,
1111
shield::Shield,
@@ -32,11 +32,7 @@ pub trait DynShield: Send + Sync {
3232
session: Session,
3333
) -> Result<Response, ShieldError>;
3434

35-
async fn sign_out(
36-
&self,
37-
request: SignOutRequest,
38-
session: Session,
39-
) -> Result<Response, ShieldError>;
35+
async fn sign_out(&self, session: Session) -> Result<Response, ShieldError>;
4036
}
4137

4238
#[async_trait]
@@ -67,12 +63,8 @@ impl<U: User> DynShield for Shield<U> {
6763
self.sign_in_callback(request, session).await
6864
}
6965

70-
async fn sign_out(
71-
&self,
72-
request: SignOutRequest,
73-
session: Session,
74-
) -> Result<Response, ShieldError> {
75-
self.sign_out(request, session).await
66+
async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
67+
self.sign_out(session).await
7668
}
7769
}
7870

@@ -109,11 +101,7 @@ impl ShieldDyn {
109101
self.0.sign_in_callback(request, session).await
110102
}
111103

112-
pub async fn sign_out(
113-
&self,
114-
request: SignOutRequest,
115-
session: Session,
116-
) -> Result<Response, ShieldError> {
117-
self.0.sign_out(request, session).await
104+
pub async fn sign_out(&self, session: Session) -> Result<Response, ShieldError> {
105+
self.0.sign_out(session).await
118106
}
119107
}

packages/integrations/shield-leptos/src/routes/sign_out.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
use leptos::prelude::*;
22

33
#[server]
4-
pub async fn sign_out(
5-
provider_id: String,
6-
subprovider_id: Option<String>,
7-
) -> Result<(), ServerFnError> {
8-
use shield::{Response, ShieldError, SignOutRequest};
4+
pub async fn sign_out() -> Result<(), ServerFnError> {
5+
use shield::{Response, ShieldError};
96

107
use crate::context::expect_server_integration;
118

@@ -14,13 +11,7 @@ pub async fn sign_out(
1411
let session = server_integration.extract_session().await;
1512

1613
let response = shield
17-
.sign_out(
18-
SignOutRequest {
19-
provider_id,
20-
subprovider_id,
21-
},
22-
session,
23-
)
14+
.sign_out(session)
2415
.await
2516
.map_err(ServerFnError::<ShieldError>::from)?;
2617

@@ -41,9 +32,6 @@ pub fn SignOut() -> impl IntoView {
4132
<h1>"Sign out"</h1>
4233

4334
<ActionForm action=sign_out>
44-
// <input name="provider_id" type="hidden" value=subprovider.provider_id />
45-
// <input name="subprovider_id" type="hidden" value=subprovider.subprovider_id />
46-
4735
<button type="submit">"Sign out"</button>
4836
</ActionForm>
4937
}

packages/integrations/shield-tower/src/service.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ where
7575
};
7676
let shield_session = Session::new(session_storage);
7777

78-
let user_id = match shield_session.data().lock() {
79-
Ok(session) => session.user_id.clone(),
78+
let authenticated = match shield_session.data().lock() {
79+
Ok(session) => session.authentication.clone(),
8080
Err(_err) => return Ok(Self::internal_server_error()),
8181
};
8282

83-
let user = if let Some(user_id) = user_id {
84-
match shield.storage().user_by_id(&user_id).await {
83+
let user = if let Some(authenticated) = authenticated {
84+
// TODO: Verify provider and subprovider still exist.
85+
86+
match shield.storage().user_by_id(&authenticated.user_id).await {
8587
Ok(user) => {
8688
if user.is_none() {
8789
if let Err(_err) = shield_session.purge().await {

packages/providers/shield-oidc/src/provider.rs

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ use openidconnect::{
77
PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse, UserInfoClaims,
88
};
99
use shield::{
10-
ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError, Response, Session,
11-
SessionError, ShieldError, SignInCallbackRequest, SignInRequest, SignOutRequest, Subprovider,
12-
UpdateUser, User,
10+
Authentication, ConfigurationError, CreateEmailAddress, CreateUser, Provider, ProviderError,
11+
Response, Session, SessionError, ShieldError, SignInCallbackRequest, SignInRequest,
12+
SignOutRequest, Subprovider, UpdateUser, User,
1313
};
1414
use tracing::debug;
1515

@@ -341,7 +341,7 @@ impl<U: User> Provider for OidcProvider<U> {
341341

342342
let connection = self
343343
.create_oidc_connection(
344-
subprovider.id,
344+
subprovider.id.clone(),
345345
user.id(),
346346
claims.subject().to_string(),
347347
token_response,
@@ -352,6 +352,8 @@ impl<U: User> Provider for OidcProvider<U> {
352352
}
353353
};
354354

355+
debug!("signed in {:?} {:?}", user.id(), connection);
356+
355357
session.renew().await?;
356358

357359
{
@@ -360,13 +362,20 @@ impl<U: User> Provider for OidcProvider<U> {
360362
.lock()
361363
.map_err(|err| SessionError::Lock(err.to_string()))?;
362364

363-
session_data.user_id = Some(user.id());
365+
session_data.csrf = None;
366+
session_data.nonce = None;
367+
session_data.verifier = None;
368+
369+
session_data.authentication = Some(Authentication {
370+
provider_id: self.id(),
371+
subprovider_id: Some(subprovider.id),
372+
user_id: user.id(),
373+
});
374+
session_data.oidc_connection_id = Some(connection.id);
364375
}
365376

366377
session.update().await?;
367378

368-
debug!("signed in {:?} {:?}", user.id(), connection);
369-
370379
// TODO: Should be configurable.
371380
Ok(Response::Redirect("/".to_owned()))
372381
}
@@ -381,25 +390,37 @@ impl<U: User> Provider for OidcProvider<U> {
381390
None => return Err(ProviderError::SubproviderMissing.into()),
382391
};
383392

384-
// TODO: find access token
385-
let token = AccessToken::new("".to_owned());
386-
387-
let client = subprovider.oidc_client().await?;
393+
let connection_id = {
394+
let session_data = session.data();
395+
let session_data = session_data
396+
.lock()
397+
.map_err(|err| SessionError::Lock(err.to_string()))?;
388398

389-
let revocation_request = match client.revoke_token(token.into()) {
390-
Ok(revocation_request) => Some(revocation_request),
391-
Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None,
392-
Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()),
399+
session_data.oidc_connection_id.clone()
393400
};
394401

395-
if let Some(revocation_request) = revocation_request {
396-
revocation_request
397-
.request_async(async_http_client)
398-
.await
399-
.expect("TODO: revocation request error");
400-
}
402+
if let Some(connection_id) = connection_id {
403+
if let Some(connection) = self.storage.oidc_connection_by_id(&connection_id).await? {
404+
debug!("revoking access token {:?}", connection.access_token);
405+
406+
let token = AccessToken::new(connection.access_token);
401407

402-
session.purge().await?;
408+
let client = subprovider.oidc_client().await?;
409+
410+
let revocation_request = match client.revoke_token(token.into()) {
411+
Ok(revocation_request) => Some(revocation_request),
412+
Err(openidconnect::ConfigurationError::MissingUrl("revocation")) => None,
413+
Err(err) => return Err(ConfigurationError::Invalid(err.to_string()).into()),
414+
};
415+
416+
if let Some(revocation_request) = revocation_request {
417+
revocation_request
418+
.request_async(async_http_client)
419+
.await
420+
.map_err(|err| ShieldError::Request(err.to_string()))?;
421+
}
422+
}
423+
}
403424

404425
// TODO: Should be configurable.
405426
Ok(Response::Redirect("/".to_owned()))

packages/providers/shield-oidc/src/storage.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ pub trait OidcStorage<U: User>: Storage<U> + Sync {
1616
subprovider_id: &str,
1717
) -> Result<Option<OidcSubprovider>, StorageError>;
1818

19+
async fn oidc_connection_by_id(
20+
&self,
21+
connection_id: &str,
22+
) -> Result<Option<OidcConnection>, StorageError>;
23+
1924
async fn oidc_connection_by_identifier(
2025
&self,
2126
subprovider_id: &str,

packages/storage/shield-memory/src/providers/oidc.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ impl OidcStorage<User> for MemoryStorage {
2727
Ok(None)
2828
}
2929

30+
async fn oidc_connection_by_id(
31+
&self,
32+
connection_id: &str,
33+
) -> Result<Option<OidcConnection>, StorageError> {
34+
Ok(self
35+
.oidc
36+
.connections
37+
.lock()
38+
.map_err(|err| StorageError::Engine(err.to_string()))?
39+
.iter()
40+
.find(|connection| connection.id == connection_id)
41+
.cloned())
42+
}
43+
3044
async fn oidc_connection_by_identifier(
3145
&self,
3246
subprovider_id: &str,

0 commit comments

Comments
 (0)