use std::{net::SocketAddr, str::FromStr}; use axum::{ extract::{FromRef, State}, response::IntoResponse, routing::get, Router, }; use nefarious_login::{ auth::AuthService, axum::{AuthController, UserFromSession}, login::{ auth_clap::{AuthEngine, ZitadelClap}, config::ConfigClap, AuthClap, }, session::{PostgresqlSessionClap, SessionBackend}, }; use tracing_subscriber::EnvFilter; #[derive(Clone)] struct AppState { auth: AuthService, } #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .init(); let auth = AuthClap { engine: AuthEngine::Zitadel, session_backend: SessionBackend::Postgresql, zitadel: ZitadelClap { authority_url: Some("https://personal-wxuujs.zitadel.cloud".into()), client_id: Some("237412977047895154@nefarious-test".into()), client_secret: Some( "rWwDi8gjNOyuMFKoOjNSlhjcVZ1B25wDh6HsDL27f0g2Hb0xGbvEf0WXFY2akOlL".into(), ), redirect_url: Some("http://localhost:3001/auth/authorized".into()), }, session: nefarious_login::session::SessionClap { postgresql: PostgresqlSessionClap { conn: Some("postgres://nefarious-test:somenotverysecurepassword@localhost:5432/nefarious-test".into()), }, }, config: ConfigClap { return_url: "http://localhost:3001/authed".into() } }; let auth_service = AuthService::new(&auth).await?; let state = AppState { auth: auth_service.clone(), }; let app = Router::new() .route("/unauthed", get(unauthed)) .route("/authed", get(authed)) .with_state(state) .nest("/auth", AuthController::new_router(auth_service).await?); let addr = SocketAddr::from_str(&format!("{}:{}", "127.0.0.1", "3000"))?; let listener = tokio::net::TcpListener::bind(&addr).await?; axum::serve(listener, app).await?; Ok(()) } impl FromRef for AuthService { fn from_ref(input: &AppState) -> Self { input.auth.clone() } } async fn unauthed() -> String { "Hello Unauthorized User".into() } #[axum::debug_handler()] async fn authed( user: UserFromSession, State(_auth_service): State, ) -> impl IntoResponse { format!("Hello authorized user: {:?}", user.user.id) }