como/como_auth/src/oauth.rs

251 lines
6.5 KiB
Rust
Raw Normal View History

use async_trait::async_trait;
use oauth2::{basic::BasicClient, AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use std::ops::Deref;
use std::sync::Arc;
#[derive(Clone, clap::Args, Debug)]
pub struct OAuthClientClap {
#[clap(flatten)]
zitadel: ZitadelClap,
#[clap(flatten)]
noop: NoopConfig,
}
// Switch for selecting the no-op OAuth client.
//
// The argument is a member of the `auth` argument group, which
// `--zitadel-auth-url` also joins.
#[derive(Debug, Clone, clap::Args)]
pub struct NoopConfig {
    // `--oauth-noop=<bool>` or `OAUTH_NOOP`; `None` when neither is supplied.
    #[arg(
        long = "oauth-noop",
        env = "OAUTH_NOOP",
        global = true,
        group = "auth"
    )]
    pub oauth_noop: Option<bool>,
}
// Raw Zitadel OAuth settings as read from the command line or environment.
//
// Every field is optional, so the flags may be omitted entirely (e.g. when the
// no-op client is chosen). All flags are `global`, so they are accepted after a
// subcommand too. `Debug` is implemented by hand below so that the client
// secret is never written to logs or test output.
//
// NOTE: plain `//` comments on purpose; `///` would become clap `--help` text.
#[derive(clap::Args, Clone, PartialEq, Eq)]
pub struct ZitadelClap {
    // Authorization endpoint URL (member of the `auth` argument group).
    #[arg(
        env = "ZITADEL_AUTH_URL",
        long = "zitadel-auth-url",
        group = "auth",
        global = true
    )]
    pub auth_url: Option<String>,
    // OAuth client identifier.
    #[arg(env = "ZITADEL_CLIENT_ID", long = "zitadel-client-id", global = true)]
    pub client_id: Option<String>,
    // OAuth client secret; redacted in `Debug` output.
    #[arg(
        env = "ZITADEL_CLIENT_SECRET",
        long = "zitadel-client-secret",
        global = true
    )]
    pub client_secret: Option<String>,
    // Redirect URL registered for this client.
    #[arg(
        env = "ZITADEL_REDIRECT_URL",
        long = "zitadel-redirect-url",
        global = true
    )]
    pub redirect_url: Option<String>,
    // Token endpoint URL.
    #[arg(env = "ZITADEL_TOKEN_URL", long = "zitadel-token-url", global = true)]
    pub token_url: Option<String>,
}

impl std::fmt::Debug for ZitadelClap {
    // Same shape as the derived output, except that a present `client_secret`
    // is rendered as `Some("<redacted>")` instead of its real value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ZitadelClap")
            .field("auth_url", &self.auth_url)
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("redirect_url", &self.redirect_url)
            .field("token_url", &self.token_url)
            .finish()
    }
}
/// Abstraction over an OAuth provider, used as a trait object inside [`OAuth`].
///
/// `#[async_trait]` is what allows the `async` method to be called through
/// `dyn OAuthClient`.
#[async_trait]
pub trait OAuthClient {
    /// Obtains a token from the provider.
    ///
    /// The success type is `()`: no token value is handed back yet, and both
    /// implementations in this module currently just return `Ok(())`.
    async fn get_token(&self) -> anyhow::Result<()>;
}
/// Thread-safe, cheaply clonable handle to some [`OAuthClient`] implementation.
///
/// Cloning only bumps the inner `Arc` reference count; all clones share the
/// same underlying client.
#[derive(Clone)]
pub struct OAuth(Arc<dyn OAuthClient + Send + Sync + 'static>);
impl OAuth {
pub fn new_zitadel(config: ZitadelConfig) -> Self {
Self(Arc::new(ZitadelOAuthClient::from(config)))
}
pub fn new_noop() -> Self {
Self(Arc::new(NoopOAuthClient {}))
}
}
/// Selects which [`OAuthClient`] implementation an [`OAuth`] handle wraps.
#[derive(Clone)]
pub enum OAuthConfig {
    /// Use the Zitadel client with these settings.
    Zitadel(ZitadelConfig),
    /// Use the no-op client.
    Noop,
}
impl From<OAuthConfig> for OAuth {
fn from(value: OAuthConfig) -> Self {
match value {
OAuthConfig::Zitadel(c) => c.into(),
OAuthConfig::Noop => Self::new_noop(),
}
}
}
/// Lets an [`OAuth`] handle be used directly as the shared client it wraps,
/// e.g. `oauth.get_token().await`.
impl Deref for OAuth {
    type Target = Arc<dyn OAuthClient + Send + Sync + 'static>;

    fn deref(&self) -> &Self::Target {
        let Self(shared) = self;
        shared
    }
}
impl From<ZitadelConfig> for OAuth {
fn from(value: ZitadelConfig) -> Self {
Self::new_zitadel(value)
}
}
// -- Noop

// OAuth client that does no work at all: `get_token` succeeds immediately.
// (Plain comment on purpose: this type derives `clap::Args`.)
#[derive(clap::Args, Clone)]
pub struct NoopOAuthClient;

#[async_trait]
impl OAuthClient for NoopOAuthClient {
    /// Nothing to fetch; always returns `Ok(())`.
    async fn get_token(&self) -> anyhow::Result<()> {
        Ok(())
    }
}
// -- Zitadel
#[derive(clap::Args, Clone)]
#[group(conflicts_with = "NoopConfig", required = false)]
pub struct ZitadelConfig {
#[clap(env = "ZITADEL_AUTH_URL", long = "zitadel-auth-url")]
auth_url: String,
#[clap(env = "ZITADEL_CLIENT_ID", long = "zitadel-client-id")]
client_id: String,
#[clap(env = "ZITADEL_CLIENT_SECRET", long = "zitadel-client-secret")]
client_secret: String,
#[clap(env = "ZITADEL_REDIRECT_URL", long = "zitadel-redirect-url")]
redirect_url: String,
#[clap(env = "ZITADEL_TOKEN_URL", long = "zitadel-token-url")]
token_url: String,
}
/// [`OAuthClient`] for Zitadel, built on the `oauth2` crate's [`BasicClient`].
pub struct ZitadelOAuthClient {
    // Pre-configured client (id, secret, auth/token/redirect URLs).
    // NOTE(review): not read yet — `get_token` is still a stub.
    client: BasicClient,
}
impl ZitadelOAuthClient {
pub fn new(
client_id: impl Into<String>,
client_secret: impl Into<String>,
redirect_url: impl Into<String>,
auth_url: impl Into<String>,
token_url: impl Into<String>,
) -> Self {
Self {
client: Self::oauth_client(ZitadelConfig {
client_id: client_id.into(),
client_secret: client_secret.into(),
redirect_url: redirect_url.into(),
auth_url: auth_url.into(),
token_url: token_url.into(),
}),
}
}
fn oauth_client(config: ZitadelConfig) -> BasicClient {
BasicClient::new(
ClientId::new(config.client_id),
Some(ClientSecret::new(config.client_secret)),
AuthUrl::new(config.auth_url).unwrap(),
Some(TokenUrl::new(config.token_url).unwrap()),
)
.set_redirect_uri(RedirectUrl::new(config.redirect_url).unwrap())
}
}
impl From<ZitadelConfig> for ZitadelOAuthClient {
fn from(value: ZitadelConfig) -> Self {
Self::new(
value.client_id,
value.client_secret,
value.redirect_url,
value.auth_url,
value.token_url,
)
}
}
#[async_trait]
impl OAuthClient for ZitadelOAuthClient {
    /// Placeholder: does not contact Zitadel and always returns `Ok(())`.
    async fn get_token(&self) -> anyhow::Result<()> {
        // NOTE(review): stub — `self.client` is not used yet.
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::oauth::{OAuth, OAuthClientClap, OAuthConfig, ZitadelClap, ZitadelConfig};
    use clap::Parser;

    // Minimal CLI that embeds the OAuth options so flag parsing can be tested.
    // (Plain comment: a `///` doc comment here would feed clap's `about`.)
    #[derive(Parser)]
    #[command(author, version, about, long_about = None)]
    pub struct Cli {
        #[clap(flatten)]
        options: OAuthClientClap,
        #[command(subcommand)]
        command: Commands,
    }

    // Dummy subcommand. The tests pass the OAuth flags *after* it, which only
    // works because those flags are declared `global = true`.
    #[derive(clap::Subcommand, Clone)]
    pub enum Commands {
        One,
    }

    #[tokio::test]
    async fn test_noop() {
        OAuth::from(OAuthConfig::Noop).get_token().await.unwrap();
    }

    #[tokio::test]
    async fn test_zitadel() {
        OAuth::from(OAuthConfig::Zitadel(ZitadelConfig {
            client_id: "something".into(),
            client_secret: "something".into(),
            redirect_url: "https://something".into(),
            auth_url: "https://something".into(),
            token_url: "https://something".into(),
        }))
        .get_token()
        .await
        .unwrap();
    }

    // Argument parsing is synchronous, so no async runtime is needed here.
    #[test]
    fn test_parse_clap_noop() {
        let cli = Cli::parse_from(["base", "one", "--oauth-noop=true"]);
        assert_eq!(cli.options.noop.oauth_noop, Some(true));
    }

    #[test]
    fn test_parse_clap_zitadel() {
        let cli = Cli::parse_from([
            "base",
            "one",
            "--zitadel-client-id=something",
            "--zitadel-client-secret=something",
            "--zitadel-auth-url=https://something",
            "--zitadel-redirect-url=https://something",
            "--zitadel-token-url=https://something",
        ]);
        pretty_assertions::assert_eq!(
            cli.options.zitadel,
            ZitadelClap {
                auth_url: Some("https://something".into()),
                client_id: Some("something".into()),
                client_secret: Some("something".into()),
                redirect_url: Some("https://something".into()),
                token_url: Some("https://something".into())
            }
        );
    }
}