use std::{net::SocketAddr, str::FromStr}; use axum::{ extract::{FromRef, State}, response::IntoResponse, routing::get, Router, }; use clap::Parser; use nefarious_login::{ auth::AuthService, axum::{AuthController, UserFromSession}, login::AuthClap, }; use tracing_subscriber::EnvFilter; #[derive(Clone)] struct AppState { auth: AuthService, } #[derive(Debug, Clone, Parser)] struct Command { #[clap(flatten)] auth: AuthClap, } #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .init(); let cmd = Command::parse_from(vec![ "base", "--auth-engine=zitadel", "--zitadel-authority-url=https://personal-wxuujs.zitadel.cloud", "--zitadel-redirect-url=http://localhost:3001/auth/authorized", "--zitadel-client-id=237412977047895154@nefarious-test", "--zitadel-client-secret=rWwDi8gjNOyuMFKoOjNSlhjcVZ1B25wDh6HsDL27f0g2Hb0xGbvEf0WXFY2akOlL", "--session-backend=postgresql", "--session-postgres-conn=postgres://nefarious-test:somenotverysecurepassword@localhost:5432/nafarious-test", "--login-return-url=http://localhost:3001/authed" ]); let auth_service = AuthService::new(&cmd.auth).await?; let state = AppState { auth: auth_service.clone(), }; let app = Router::new() .route("/unauthed", get(unauthed)) .route("/authed", get(authed)) .with_state(state) .nest("/auth", AuthController::new_router(auth_service).await?); let addr = SocketAddr::from(([127, 0, 0, 1], 3001)); println!("listening on: {addr}"); println!("open browser at: http://localhost:3001/auth/zitadel"); let addr = SocketAddr::from_str(&format!("{}:{}", "127.0.0.1", "3000"))?; let listener = tokio::net::TcpListener::bind(&addr).await?; axum::serve(listener, app).await?; Ok(()) } impl FromRef for AuthService { fn from_ref(input: &AppState) -> Self { input.auth.clone() } } async fn unauthed() -> String { "Hello Unauthorized User".into() } #[axum::debug_handler()] async fn authed( user: UserFromSession, State(_auth_service): State, ) -> impl IntoResponse { format!("Hello authorized user: {:?}", user.user.id) }