use async_trait::async_trait; use oauth2::reqwest::async_http_client; use oauth2::url::Url; use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; use std::ops::Deref; use std::sync::Arc; use crate::ZitadelClap; #[async_trait] pub trait OAuthClient { async fn get_token(&self) -> anyhow::Result<()>; async fn authorize_url(&self) -> anyhow::Result; async fn exchange(&self, code: &str) -> anyhow::Result; } pub struct OAuth(Arc); impl OAuth { pub fn new_zitadel(config: ZitadelConfig) -> Self { Self(Arc::new(ZitadelOAuthClient::from(config))) } pub fn new_noop() -> Self { Self(Arc::new(NoopOAuthClient {})) } } impl Deref for OAuth { type Target = Arc; fn deref(&self) -> &Self::Target { &self.0 } } impl From for OAuth { fn from(value: ZitadelConfig) -> Self { Self::new_zitadel(value) } } // -- Noop #[derive(clap::Args, Clone)] pub struct NoopOAuthClient; #[async_trait] impl OAuthClient for NoopOAuthClient { async fn get_token(&self) -> anyhow::Result<()> { Ok(()) } async fn authorize_url(&self) -> anyhow::Result { Ok(Url::parse("http://localhost:3000/auth/zitadel").unwrap()) } async fn exchange(&self, _code: &str) -> anyhow::Result { Ok(String::new()) } } // -- Zitadel #[derive(Clone)] pub struct ZitadelConfig { auth_url: String, client_id: String, client_secret: String, redirect_url: String, token_url: String, authority_url: String, } pub struct ZitadelOAuthClient { client: BasicClient, } impl ZitadelOAuthClient { pub fn new( client_id: impl Into, client_secret: impl Into, redirect_url: impl Into, auth_url: impl Into, token_url: impl Into, authority_url: impl Into, ) -> Self { Self { client: Self::oauth_client(ZitadelConfig { client_id: client_id.into(), client_secret: client_secret.into(), redirect_url: redirect_url.into(), auth_url: auth_url.into(), token_url: token_url.into(), authority_url: authority_url.into(), }), } } fn oauth_client(config: ZitadelConfig) -> BasicClient { BasicClient::new( ClientId::new(config.client_id), Some(ClientSecret::new(config.client_secret)), AuthUrl::new(config.auth_url).unwrap(), Some(TokenUrl::new(config.token_url).unwrap()), ) .set_redirect_uri(RedirectUrl::new(config.redirect_url).unwrap()) } } impl From for ZitadelOAuthClient { fn from(value: ZitadelConfig) -> Self { Self::new( value.client_id, value.client_secret, value.redirect_url, value.auth_url, value.token_url, value.authority_url, ) } } impl TryFrom for ZitadelConfig { type Error = anyhow::Error; fn try_from(value: ZitadelClap) -> Result { Ok(Self { auth_url: value .auth_url .ok_or(anyhow::anyhow!("auth_url was not set"))?, client_id: value .client_id .ok_or(anyhow::anyhow!("client_id was not set"))?, client_secret: value .client_secret .ok_or(anyhow::anyhow!("client_secret was not set"))?, redirect_url: value .redirect_url .ok_or(anyhow::anyhow!("redirect_url was not set"))?, token_url: value .token_url .ok_or(anyhow::anyhow!("token_url was not set"))?, authority_url: value .authority_url .ok_or(anyhow::anyhow!("authority_url was not set"))?, }) } } #[async_trait] impl OAuthClient for ZitadelOAuthClient { async fn get_token(&self) -> anyhow::Result<()> { Ok(()) } async fn authorize_url(&self) -> anyhow::Result { let (auth_url, _csrf_token) = self .client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .add_scope(Scope::new("openid".to_string())) .url(); Ok(auth_url) } async fn exchange(&self, code: &str) -> anyhow::Result { let token = self .client .exchange_code(AuthorizationCode::new(code.to_string())) .request_async(async_http_client) .await?; Ok(token.access_token().secret().clone()) } } #[cfg(test)] mod tests { use crate::ZitadelClap; use clap::Parser; use sealed_test::prelude::*; #[derive(Parser)] #[command(author, version, about, long_about = None)] pub struct Cli { #[clap(flatten)] options: ZitadelClap, } #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct CliSubCommand { #[command(subcommand)] command: Commands, } #[derive(clap::Subcommand, Clone, Debug, Eq, PartialEq)] pub enum Commands { One { #[clap(flatten)] options: ZitadelClap, }, } #[tokio::test] async fn test_parse_clap_zitadel() { let cli: Cli = Cli::parse_from(&[ "base", "--zitadel-client-id=something", "--zitadel-client-secret=something", "--zitadel-auth-url=https://something", "--zitadel-redirect-url=https://something", "--zitadel-token-url=https://something", "--zitadel-authority-url=https://something", ]); println!("{:?}", cli.options); pretty_assertions::assert_eq!( cli.options, ZitadelClap { auth_url: Some("https://something".into()), client_id: Some("something".into()), client_secret: Some("something".into()), redirect_url: Some("https://something".into()), token_url: Some("https://something".into()), authority_url: Some("https://something".into()), } ); } #[test] fn test_parse_clap_zitadel_fails_require_all() { let cli = CliSubCommand::try_parse_from(&[ "base", "one", // "--zitadel-client-id=something", // We want to trigger missing variable "--zitadel-client-secret=something", "--zitadel-auth-url=https://something", "--zitadel-redirect-url=https://something", "--zitadel-token-url=https://something", "--zitadel-authority-url=https://something", ]); pretty_assertions::assert_eq!(cli.is_err(), true); } #[sealed_test] fn test_parse_clap_env_zitadel() { std::env::set_var("ZITADEL_CLIENT_ID", "something"); std::env::set_var("ZITADEL_CLIENT_SECRET", "something"); std::env::set_var("ZITADEL_AUTH_URL", "https://something"); std::env::set_var("ZITADEL_REDIRECT_URL", "https://something"); std::env::set_var("ZITADEL_TOKEN_URL", "https://something"); std::env::set_var("ZITADEL_AUTHORITY_URL", "https://something"); let cli = CliSubCommand::parse_from(&["base", "one"]); pretty_assertions::assert_eq!( cli.command, Commands::One { options: ZitadelClap { auth_url: Some("https://something".into()), client_id: Some("something".into()), client_secret: Some("something".into()), redirect_url: Some("https://something".into()), token_url: Some("https://something".into()), authority_url: Some("https://something".into()), } } ); } #[test] fn test_parse_clap_defaults_to_noop() { let cli = CliSubCommand::parse_from(&["base", "one"]); pretty_assertions::assert_eq!( cli.command, Commands::One { options: ZitadelClap { auth_url: None, client_id: None, client_secret: None, redirect_url: None, token_url: None, authority_url: None, }, } ); } }