diff --git a/crates/nefarious-login/src/auth.rs b/crates/nefarious-login/src/auth.rs index 81b7b0f..8099610 100644 --- a/crates/nefarious-login/src/auth.rs +++ b/crates/nefarious-login/src/auth.rs @@ -9,18 +9,18 @@ use crate::{ introspection::{IdToken, IntrospectionService}, login::{auth_clap::AuthEngine, config::ConfigClap, AuthClap}, oauth::{zitadel::ZitadelConfig, OAuth}, - session::{SessionService, User}, + session::{AppSession, SessionService, User}, }; #[async_trait] pub trait Auth { - async fn login(&self, return_url: Option) -> anyhow::Result; + async fn login(&self, return_url: Option) -> anyhow::Result<(HeaderMap, Url)>; async fn login_token(&self, user: &str, password: &str) -> anyhow::Result; async fn login_authorized( &self, code: &str, state: &str, - return_url: Option, + app_session_cookie: Option, ) -> anyhow::Result<(HeaderMap, Url)>; async fn get_user_from_session(&self, cookie: &str) -> anyhow::Result; } @@ -89,19 +89,31 @@ pub struct ZitadelAuthService { config: ConfigClap, } pub static COOKIE_NAME: &str = "SESSION"; +pub static COOKIE_APP_SESSION_NAME: &str = "APP_SESSION"; #[async_trait] impl Auth for ZitadelAuthService { - async fn login(&self, return_url: Option) -> anyhow::Result { - let authorize_url = self.oauth.authorize_url(return_url).await?; + async fn login(&self, return_url: Option) -> anyhow::Result<(HeaderMap, Url)> { + let mut headers = HeaderMap::new(); + if let Some(return_url) = return_url.clone() { + let cookie_value = self.session.insert(AppSession { return_url }).await?; - Ok(authorize_url) + let cookie = format!( + "{}={}; SameSite=Lax; Path=/", + COOKIE_APP_SESSION_NAME, cookie_value + ); + headers.insert(SET_COOKIE, cookie.parse().unwrap()); + } + + let authorize_url = self.oauth.authorize_url().await?; + + Ok((headers, authorize_url)) } async fn login_authorized( &self, code: &str, _state: &str, - return_path: Option, + app_session_cookie: Option, ) -> anyhow::Result<(HeaderMap, Url)> { let token = self.oauth.exchange(code).await?; let id_token = self.introspection.get_id_token(token.as_str()).await?; @@ -113,8 +125,16 @@ impl Auth for ZitadelAuthService { headers.insert(SET_COOKIE, cookie.parse().unwrap()); let mut return_url = self.config.return_url.clone(); - if let Some(return_path) = return_path { - return_url.push_str(&format!("?returnPath={return_path}")); + if let Some(cookie) = app_session_cookie { + if let Some(session) = self.session.get(&cookie).await? { + if session.return_url.starts_with('/') { + let mut url = Url::parse(&return_url)?; + url.set_path(&session.return_url); + return_url = url.to_string(); + } else { + return_url = session.return_url; + } + } } Ok(( @@ -141,7 +161,7 @@ pub struct NoopAuthService { #[async_trait] impl Auth for NoopAuthService { - async fn login(&self, return_url: Option) -> anyhow::Result { + async fn login(&self, return_url: Option) -> anyhow::Result<(HeaderMap, Url)> { let url = Url::parse(&format!( "{}/auth/authorized?code=noop&state=noop", self.config @@ -151,13 +171,13 @@ impl Auth for NoopAuthService { .unwrap() )) .unwrap(); - Ok(url) + Ok((HeaderMap::new(), url)) } async fn login_authorized( &self, _code: &str, _state: &str, - _return_url: Option, + _app_session_cookie: Option, ) -> anyhow::Result<(HeaderMap, Url)> { let cookie_value = self .session diff --git a/crates/nefarious-login/src/axum.rs b/crates/nefarious-login/src/axum.rs index acb74d6..424f067 100644 --- a/crates/nefarious-login/src/axum.rs +++ b/crates/nefarious-login/src/axum.rs @@ -7,13 +7,14 @@ use axum::response::{ErrorResponse, IntoResponse, Redirect}; use axum::routing::get; use axum::{async_trait, Json, RequestPartsExt, Router}; +use axum_extra::extract::CookieJar; use axum_extra::headers::authorization::Basic; use axum_extra::headers::{Authorization, Cookie}; use axum_extra::TypedHeader; use serde::Deserialize; use serde_json::json; -use crate::auth::AuthService; +use crate::auth::{AuthService, COOKIE_APP_SESSION_NAME}; use crate::session::User; #[derive(Debug, Deserialize)] @@ -50,9 +51,9 @@ where pub async fn zitadel_auth( State(auth_service): State, ) -> Result { - let url = auth_service.login(None).await.into_response()?; + let (headers, url) = auth_service.login(None).await.into_response()?; - Ok(Redirect::to(url.as_ref())) + Ok((headers, Redirect::to(url.as_ref()))) } #[derive(Debug, Deserialize)] @@ -60,16 +61,19 @@ pub async fn zitadel_auth( pub struct AuthRequest { code: String, state: String, - #[serde(alias = "returnUrl")] - return_url: Option, } pub async fn login_authorized( Query(query): Query, State(auth_service): State, + cookie_jar: CookieJar, ) -> Result { + let cookie_value = cookie_jar + .get(COOKIE_APP_SESSION_NAME) + .map(|c| c.value().to_string()); + let (headers, url) = auth_service - .login_authorized(&query.code, &query.state, query.return_url) + .login_authorized(&query.code, &query.state, cookie_value) .await .into_response()?; diff --git a/crates/nefarious-login/src/oauth.rs b/crates/nefarious-login/src/oauth.rs index 27059d8..bc0da8c 100644 --- a/crates/nefarious-login/src/oauth.rs +++ b/crates/nefarious-login/src/oauth.rs @@ -31,7 +31,7 @@ impl Deref for OAuth { #[async_trait] pub trait OAuthClient { async fn get_token(&self) -> anyhow::Result<()>; - async fn authorize_url(&self, return_url: Option) -> anyhow::Result; + async fn authorize_url(&self) -> anyhow::Result; async fn exchange(&self, code: &str) -> anyhow::Result; } diff --git a/crates/nefarious-login/src/oauth/noop.rs b/crates/nefarious-login/src/oauth/noop.rs index 020779d..5d5432a 100644 --- a/crates/nefarious-login/src/oauth/noop.rs +++ b/crates/nefarious-login/src/oauth/noop.rs @@ -10,7 +10,7 @@ impl OAuthClient for NoopOAuthClient { async fn get_token(&self) -> anyhow::Result<()> { Ok(()) } - async fn authorize_url(&self, return_url: Option) -> anyhow::Result { + async fn authorize_url(&self) -> anyhow::Result { Ok(Url::parse("http://localhost:3000/auth/zitadel").unwrap()) } diff --git a/crates/nefarious-login/src/oauth/zitadel.rs b/crates/nefarious-login/src/oauth/zitadel.rs index a144157..5daf5e8 100644 --- a/crates/nefarious-login/src/oauth/zitadel.rs +++ b/crates/nefarious-login/src/oauth/zitadel.rs @@ -104,7 +104,7 @@ impl OAuthClient for ZitadelOAuthClient { async fn get_token(&self) -> anyhow::Result<()> { Ok(()) } - async fn authorize_url(&self, return_url: Option) -> anyhow::Result { + async fn authorize_url(&self) -> anyhow::Result { let req = self .client .authorize_url(CsrfToken::new_random) @@ -113,18 +113,6 @@ impl OAuthClient for ZitadelOAuthClient { .add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("profile".to_string())); - let req = { - if let Some(return_url) = return_url { - let mut redirect_url = self.client.redirect_url().unwrap().as_str().to_string(); - - redirect_url.push_str(&format!("?returnUrl={}", return_url)); - - req.set_redirect_uri(std::borrow::Cow::Owned(RedirectUrl::new(redirect_url)?)) - } else { - req - } - }; - let (auth_url, _csrf_token) = req.url(); Ok(auth_url) diff --git a/crates/nefarious-login/src/session.rs b/crates/nefarious-login/src/session.rs index 8db0793..86083a5 100644 --- a/crates/nefarious-login/src/session.rs +++ b/crates/nefarious-login/src/session.rs @@ -27,12 +27,13 @@ pub struct PostgresqlSessionClap { #[async_trait] pub trait Session { + async fn insert(&self, app_session: AppSession) -> anyhow::Result; async fn insert_user(&self, id: &str, id_token: IdToken) -> anyhow::Result; async fn get_user(&self, cookie: &str) -> anyhow::Result>; + async fn get(&self, cookie: &str) -> anyhow::Result>; } pub struct SessionService(Arc); - impl SessionService { pub async fn new(config: &AuthClap) -> anyhow::Result { match config.session_backend { @@ -77,8 +78,26 @@ pub struct User { pub name: String, } +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct AppSession { + pub return_url: String, +} + #[async_trait] impl Session for PostgresSessionService { + async fn insert(&self, app_session: AppSession) -> anyhow::Result { + let mut session = AxumSession::new(); + session.insert("app_session", app_session)?; + + let cookie = self + .store + .store_session(session) + .await? + .ok_or(anyhow::anyhow!("failed to store app session"))?; + + Ok(cookie) + } + async fn insert_user(&self, _id: &str, id_token: IdToken) -> anyhow::Result { let mut session = AxumSession::new(); session.insert( @@ -117,6 +136,18 @@ impl Session for PostgresSessionService { Err(anyhow::anyhow!("No session found for cookie")) } } + + async fn get(&self, cookie: &str) -> anyhow::Result> { + let Some(session) = self.store.load_session(cookie.to_string()).await? else { + return Ok(None); + }; + + let Some(session) = session.get::("app_session") else { + anyhow::bail!("failed to deserialize app_session from cookie"); + }; + + Ok(Some(session)) + } } #[derive(Default)] @@ -126,6 +157,9 @@ pub struct InMemorySessionService { #[async_trait] impl Session for InMemorySessionService { + async fn insert(&self, app_session: AppSession) -> anyhow::Result { + todo!() + } async fn insert_user(&self, _id: &str, id_token: IdToken) -> anyhow::Result { let user = User { id: id_token.sub, @@ -145,4 +179,7 @@ impl Session for InMemorySessionService { Ok(user) } + async fn get(&self, cookie: &str) -> anyhow::Result> { + todo!() + } } diff --git a/examples/custom_redirect/src/main.rs b/examples/custom_redirect/src/main.rs index a5a13bb..5ad4d07 100644 --- a/examples/custom_redirect/src/main.rs +++ b/examples/custom_redirect/src/main.rs @@ -45,7 +45,7 @@ async fn main() -> anyhow::Result<()> { conn: Some("postgres://nefarious-test:somenotverysecurepassword@localhost:5432/nefarious-test".into()), }, }, - config: ConfigClap { return_url: "http://localhost:3001".into() } // this normally has /authed + config: ConfigClap { return_url: "http://localhost:3001/authed".into() } // this normally has /authed }; let auth_service = AuthService::new(&auth).await?; @@ -76,9 +76,9 @@ impl FromRef for AuthService { } async fn login(State(auth_service): State) -> impl IntoResponse { - let url = auth_service.login(Some("/authed".into())).await.unwrap(); + let (headers, url) = auth_service.login(Some("/authed".into())).await.unwrap(); - Redirect::to(url.as_ref()) + (headers, Redirect::to(url.as_ref())) } async fn unauthed() -> String {