diff --git a/crates/mad/Cargo.toml b/crates/mad/Cargo.toml index 81f5a57..d55768b 100644 --- a/crates/mad/Cargo.toml +++ b/crates/mad/Cargo.toml @@ -4,8 +4,8 @@ version.workspace = true description = "notmad is a life-cycle manager for long running rust operations" license = "MIT" repository = "https://github.com/kjuulh/mad" -author = "kjuulh" -edition = "2021" +authors = ["kjuulh"] +edition = "2024" [dependencies] anyhow.workspace = true diff --git a/crates/mad/src/lib.rs b/crates/mad/src/lib.rs index 4b4a4d6..aa2478f 100644 --- a/crates/mad/src/lib.rs +++ b/crates/mad/src/lib.rs @@ -1,10 +1,14 @@ use futures::stream::FuturesUnordered; use futures_util::StreamExt; use std::{fmt::Display, sync::Arc}; -use tokio::signal::unix::{signal, SignalKind}; +use tokio::signal::unix::{SignalKind, signal}; use tokio_util::sync::CancellationToken; +use crate::waiter::Waiter; + +mod waiter; + #[derive(thiserror::Error, Debug)] pub enum MadError { #[error("component failed: {0}")] @@ -70,6 +74,17 @@ impl Mad { self } + pub fn add_conditional(&mut self, condition: bool, component: impl IntoComponent) -> &mut Self { + if condition { + self.components.push(component.into_component()); + } else { + self.components + .push(Waiter::new(component.into_component()).into_component()) + } + + self + } + pub fn add_fn(&mut self, f: F) -> &mut Self where F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, @@ -100,7 +115,7 @@ impl Mad { (Err(run), Err(close)) => { return Err(MadError::AggregateError(AggregateError { errors: vec![run, close], - })) + })); } (Ok(_), Ok(_)) => {} (Ok(_), Err(close)) => return Err(close), diff --git a/crates/mad/src/waiter.rs b/crates/mad/src/waiter.rs new file mode 100644 index 0000000..982ffb6 --- /dev/null +++ b/crates/mad/src/waiter.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; + +use crate::{Component, MadError}; + +pub struct Waiter { + comp: Arc, +} + +impl Waiter { + pub fn new(c: Arc) -> Self { + Self { comp: c } + } +} + +#[async_trait] +impl Component for Waiter { + fn name(&self) -> Option { + match self.comp.name() { + Some(name) => Some(format!("waiter/{name}")), + None => Some("waiter".into()), + } + } + + async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> { + cancellation_token.cancelled().await; + + Ok(()) + } +}