diff --git a/crates/nefarious-login/src/axum.rs b/crates/nefarious-login/src/axum.rs index 424f067..1247cd3 100644 --- a/crates/nefarious-login/src/axum.rs +++ b/crates/nefarious-login/src/axum.rs @@ -2,8 +2,8 @@ use std::fmt::Display; use axum::extract::{FromRef, FromRequestParts, Query, State}; use axum::http::request::Parts; -use axum::http::StatusCode; -use axum::response::{ErrorResponse, IntoResponse, Redirect}; +use axum::http::{HeaderMap, StatusCode, Uri}; +use axum::response::{ErrorResponse, IntoResponse, Redirect, Response}; use axum::routing::get; use axum::{async_trait, Json, RequestPartsExt, Router}; @@ -97,13 +97,21 @@ pub struct UserFromSession { pub static COOKIE_NAME: &str = "SESSION"; +pub struct AuthRedirect((HeaderMap, String)); + +impl IntoResponse for AuthRedirect { + fn into_response(self) -> Response { + (self.0 .0, Redirect::temporary(&self.0 .1.as_str())).into_response() + } +} + #[async_trait] impl FromRequestParts for UserFromSession where AuthService: FromRef, S: Send + Sync, { - type Rejection = (StatusCode, &'static str); + type Rejection = AuthRedirect; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let auth_service = AuthService::from_ref(state); @@ -114,16 +122,21 @@ where let basic: Option>> = parts.extract().await.unwrap(); if let Some(basic) = basic { - let token = auth_service + let token = match auth_service .login_token(basic.username(), basic.password()) .await .into_response() - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "could not get token from basic", - ) - })?; + { + Ok(login) => login, + Err(e) => { + tracing::info!("did not find a basic login token, will trigger login"); + let (headers, url) = auth_service + .login(Some(parts.uri.to_string())) + .await + .expect("to be able to request login"); + return Err(AuthRedirect((headers, url.to_string()))); + } + }; return Ok(UserFromSession { user: User { @@ -134,24 +147,32 @@ where }); } - return Err(anyhow::anyhow!("No session was found")) - .into_response() - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "did not find a cookie"))?; + tracing::info!("did not find a cookie, will trigger login"); + let (headers, url) = auth_service + .login(Some(parts.uri.to_string())) + .await + .expect("to be able to request login"); + return Err(AuthRedirect((headers, url.to_string()))); } let session_cookie = session_cookie.unwrap(); // continue to decode the session cookie - let user = auth_service + let user = match auth_service .get_user_from_session(session_cookie) .await .into_response() - .map_err(|_| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "failed to decode session cookie", - ) - })?; + { + Ok(user) => user, + Err(_) => { + tracing::info!("could not get user from session, will trigger login"); + let (headers, url) = auth_service + .login(Some(parts.uri.to_string())) + .await + .expect("to be able to request login"); + return Err(AuthRedirect((headers, url.to_string()))); + } + }; Ok(UserFromSession { user }) }