use std::fmt::Display; use crate::router::AppState; use axum::extract::{FromRef, FromRequestParts, Query, State}; use axum::headers::authorization::Basic; use axum::headers::{Authorization, Cookie}; use axum::http::request::Parts; use axum::http::StatusCode; use axum::response::{ErrorResponse, IntoResponse, Redirect}; use axum::routing::get; use axum::{async_trait, Json, RequestPartsExt, Router, TypedHeader}; use como_domain::users::User; use como_infrastructure::register::ServiceRegister; use serde::Deserialize; use serde_json::json; #[derive(Debug, Deserialize)] pub struct ZitadelAuthParams { #[allow(dead_code)] return_url: Option, } trait AnyhowExtensions where E: Display, { fn into_response(self) -> Result; } impl AnyhowExtensions for anyhow::Result where E: Display, { fn into_response(self) -> Result { match self { Ok(o) => Ok(o), Err(e) => { tracing::error!("failed with anyhow error: {}", e); Err(ErrorResponse::from(( StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "status": "something", })), ))) } } } } pub async fn zitadel_auth( State(services): State, ) -> Result { let url = services.auth_service.login().await.into_response()?; Ok(Redirect::to(&url.to_string())) } #[derive(Debug, Deserialize)] #[allow(dead_code)] pub struct AuthRequest { code: String, state: String, } pub async fn login_authorized( Query(query): Query, State(services): State, ) -> Result { let (headers, url) = services .auth_service .login_authorized(&query.code, &query.state) .await .into_response()?; Ok((headers, Redirect::to(url.as_str()))) } pub struct AuthController; impl AuthController { pub async fn new_router( _service_register: ServiceRegister, app_state: AppState, ) -> anyhow::Result { Ok(Router::new() .route("/zitadel", get(zitadel_auth)) .route("/authorized", get(login_authorized)) .with_state(app_state)) } } pub struct UserFromSession { pub user: User, } pub static COOKIE_NAME: &str = "SESSION"; #[async_trait] impl FromRequestParts for UserFromSession where ServiceRegister: FromRef, S: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let services = ServiceRegister::from_ref(state); let cookie: Option> = parts.extract().await.unwrap(); let session_cookie = cookie.as_ref().and_then(|cookie| cookie.get(COOKIE_NAME)); if let None = session_cookie { let basic: Option>> = parts.extract().await.unwrap(); if let Some(basic) = basic { let token = services .auth_service .login_token(basic.username(), basic.password()) .await .into_response() .map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, "could not get token from basic", ) })?; return Ok(UserFromSession { user: User { id: token }, }); } return Err(anyhow::anyhow!("No session was found")) .into_response() .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "did not find a cookie"))?; } let session_cookie = session_cookie.unwrap(); // continue to decode the session cookie let user = services .auth_service .get_user_from_session(session_cookie) .await .into_response() .map_err(|_| { ( StatusCode::INTERNAL_SERVER_ERROR, "failed to decode session cookie", ) })?; Ok(UserFromSession { user: User { id: user.id }, }) } }