diff --git a/Cargo.lock b/Cargo.lock index dc17b04..1eba3c2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -236,7 +236,7 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "mad" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "async-trait", @@ -247,6 +247,7 @@ dependencies = [ "tokio", "tokio-util", "tracing", + "tracing-subscriber", "tracing-test", ] diff --git a/crates/mad/Cargo.toml b/crates/mad/Cargo.toml index cafaae0..543ab64 100644 --- a/crates/mad/Cargo.toml +++ b/crates/mad/Cargo.toml @@ -15,4 +15,5 @@ tokio-util = "0.7.11" tracing.workspace = true [dev-dependencies] +tracing-subscriber = "0.3.18" tracing-test = { version = "0.2.5", features = ["no-env-filter"] } diff --git a/crates/mad/examples/basic/main.rs b/crates/mad/examples/basic/main.rs new file mode 100644 index 0000000..8f7bce0 --- /dev/null +++ b/crates/mad/examples/basic/main.rs @@ -0,0 +1,40 @@ +use async_trait::async_trait; +use rand::Rng; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +struct WaitServer {} +#[async_trait] +impl mad::Component for WaitServer { + fn name(&self) -> Option { + Some("WaitServer".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + let millis_wait = rand::thread_rng().gen_range(500..3000); + + tracing::debug!("waiting: {}ms", millis_wait); + + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(Level::TRACE) + .init(); + + mad::Mad::builder() + .add(WaitServer {}) + .add(WaitServer {}) + .add(WaitServer {}) + .add(WaitServer {}) + .run() + .await?; + + Ok(()) +} diff --git a/crates/mad/examples/error_log/main.rs b/crates/mad/examples/error_log/main.rs new file mode 100644 index 0000000..47da2d5 --- /dev/null +++ b/crates/mad/examples/error_log/main.rs @@ -0,0 +1,42 @@ +use async_trait::async_trait; +use rand::Rng; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +struct ErrorServer {} +#[async_trait] +impl mad::Component for ErrorServer { + fn name(&self) -> Option { + Some("ErrorServer".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + let millis_wait = rand::thread_rng().gen_range(500..3000); + + tracing::debug!("waiting: {}ms", millis_wait); + + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; + + Err(mad::MadError::Inner(anyhow::anyhow!("expected error"))) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(Level::TRACE) + .init(); + + // Do note that only the first server which returns an error is guaranteed to be handled. This is because if servers don't respect cancellation, they will be dropped + + mad::Mad::builder() + .add(ErrorServer {}) + .add(ErrorServer {}) + .add(ErrorServer {}) + .add(ErrorServer {}) + .run() + .await?; + + Ok(()) +} diff --git a/crates/mad/examples/fn/main.rs b/crates/mad/examples/fn/main.rs new file mode 100644 index 0000000..02bba2e --- /dev/null +++ b/crates/mad/examples/fn/main.rs @@ -0,0 +1,47 @@ +use async_trait::async_trait; +use rand::Rng; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +struct WaitServer {} +#[async_trait] +impl mad::Component for WaitServer { + fn name(&self) -> Option { + Some("WaitServer".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + let millis_wait = rand::thread_rng().gen_range(500..3000); + + tracing::debug!("waiting: {}ms", millis_wait); + + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(Level::TRACE) + .init(); + + mad::Mad::builder() + .add(WaitServer {}) + .add_fn(|cancel| async move { + let millis_wait = 50; + + tracing::debug!("waiting: {}ms", millis_wait); + + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; + + Ok(()) + }) + .run() + .await?; + + Ok(()) +} diff --git a/crates/mad/examples/signals/main.rs b/crates/mad/examples/signals/main.rs new file mode 100644 index 0000000..21594c6 --- /dev/null +++ b/crates/mad/examples/signals/main.rs @@ -0,0 +1,69 @@ +use async_trait::async_trait; +use rand::Rng; +use tokio_util::sync::CancellationToken; +use tracing::Level; + +struct WaitServer {} +#[async_trait] +impl mad::Component for WaitServer { + fn name(&self) -> Option { + Some("WaitServer".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + let millis_wait = rand::thread_rng().gen_range(500..3000); + + tracing::debug!("waiting: {}ms", millis_wait); + + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await; + + Ok(()) + } +} + +struct RespectCancel {} +#[async_trait] +impl mad::Component for RespectCancel { + fn name(&self) -> Option { + Some("RespectCancel".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + cancellation.cancelled().await; + tracing::debug!("stopping because job is cancelled"); + + Ok(()) + } +} + +struct NeverStopServer {} +#[async_trait] +impl mad::Component for NeverStopServer { + fn name(&self) -> Option { + Some("NeverStopServer".into()) + } + + async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> { + // Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely + tokio::time::sleep(std::time::Duration::from_millis(999999999)).await; + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(Level::TRACE) + .init(); + + mad::Mad::builder() + .add(WaitServer {}) + .add(NeverStopServer {}) + .add(RespectCancel {}) + .run() + .await?; + + Ok(()) +} diff --git a/crates/mad/src/lib.rs b/crates/mad/src/lib.rs index 4a5c3f0..0f57fcf 100644 --- a/crates/mad/src/lib.rs +++ b/crates/mad/src/lib.rs @@ -49,6 +49,11 @@ pub struct Mad { should_cancel: Option, } +struct CompletionResult { + res: Result<(), MadError>, + name: Option, +} + impl Mad { pub fn builder() -> Self { Self { @@ -64,6 +69,16 @@ impl Mad { self } + pub fn add_fn(&mut self, f: F) -> &mut Self + where + F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, + Fut: futures::Future> + Send + 'static, + { + let comp = ClosureComponent { inner: Box::new(f) }; + + self.add(comp) + } + pub fn cancellation(&mut self, should_cancel: Option) -> &mut Self { self.should_cancel = should_cancel; @@ -122,22 +137,23 @@ impl Mad { let cancellation_token = cancellation_token.child_token(); let job_cancellation = job_cancellation.child_token(); - let (error_tx, error_rx) = tokio::sync::mpsc::channel::>(1); + let (error_tx, error_rx) = tokio::sync::mpsc::channel::(1); channels.push(error_rx); tokio::spawn(async move { - tracing::debug!(component = &comp.name(), "mad running"); + let name = comp.name().clone(); + + tracing::debug!(component = name, "mad running"); tokio::select! { _ = cancellation_token.cancelled() => { - error_tx.send(Ok(())).await + error_tx.send(CompletionResult { res: Ok(()) , name }).await } res = comp.run(job_cancellation) => { - error_tx.send(res).await - + error_tx.send(CompletionResult { res , name }).await } _ = tokio::signal::ctrl_c() => { - error_tx.send(Ok(())).await + error_tx.send(CompletionResult { res: Ok(()) , name }).await } } }); @@ -149,17 +165,24 @@ impl Mad { } while let Some(Some(msg)) = futures.next().await { - tracing::trace!("received end signal from a component"); - - if let Err(e) = msg { - tracing::debug!(error = e.to_string(), "stopping running components"); - job_cancellation.cancel(); - - if let Some(cancel_wait) = self.should_cancel { - tokio::time::sleep(cancel_wait).await; - - cancellation_token.cancel(); + match msg.res { + Err(e) => { + tracing::debug!( + error = e.to_string(), + component = msg.name, + "component ran to completion with error" + ); } + Ok(_) => { + tracing::debug!(component = msg.name, "component ran to completion"); + } + } + + job_cancellation.cancel(); + if let Some(cancel_wait) = self.should_cancel { + tokio::time::sleep(cancel_wait).await; + + cancellation_token.cancel(); } } @@ -211,3 +234,34 @@ impl IntoComponent for T { Arc::new(self) } } + +struct ClosureComponent +where + F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, + Fut: futures::Future> + Send + 'static, +{ + inner: Box, +} + +impl ClosureComponent +where + F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, + Fut: futures::Future> + Send + 'static, +{ + pub async fn execute(&self, cancellation_token: CancellationToken) -> Result<(), MadError> { + (*self.inner)(cancellation_token).await?; + + Ok(()) + } +} + +#[async_trait::async_trait] +impl Component for ClosureComponent +where + F: Fn(CancellationToken) -> Fut + Send + Sync + 'static, + Fut: futures::Future> + Send + 'static, +{ + async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> { + self.execute(cancellation_token).await + } +}