diff --git a/Cargo.lock b/Cargo.lock index 86ef8a2..323893b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -266,7 +266,7 @@ dependencies = [ [[package]] name = "notmad" -version = "0.4.0" +version = "0.5.0" dependencies = [ "anyhow", "async-trait", diff --git a/crates/mad/src/lib.rs b/crates/mad/src/lib.rs index dd4d1ea..b73c7a4 100644 --- a/crates/mad/src/lib.rs +++ b/crates/mad/src/lib.rs @@ -154,16 +154,38 @@ impl Mad { res = comp.run(job_cancellation) => { error_tx.send(CompletionResult { res , name }).await } - _ = tokio::signal::ctrl_c() => { - error_tx.send(CompletionResult { res: Ok(()) , name }).await - } - _ = signal_unix_terminate() => { - error_tx.send(CompletionResult { res: Ok(()) , name }).await - } } }); } + tokio::spawn({ + let cancellation_token = cancellation_token.child_token(); + let wait_cancel = self.should_cancel; + + async move { + let should_cancel = + |cancel: CancellationToken, wait: Option| async move { + if let Some(cancel_wait) = wait { + tokio::time::sleep(cancel_wait).await; + + cancel.cancel(); + } + }; + + tokio::select! { + _ = cancellation_token.cancelled() => { + job_cancellation.cancel(); + } + _ = tokio::signal::ctrl_c() => { + should_cancel(job_cancellation, wait_cancel).await; + } + _ = signal_unix_terminate() => { + should_cancel(job_cancellation, wait_cancel).await; + } + } + } + }); + let mut futures = FuturesUnordered::new(); for channel in channels.iter_mut() { futures.push(channel.recv()); @@ -182,13 +204,6 @@ impl Mad { tracing::debug!(component = msg.name, "component ran to completion"); } } - - job_cancellation.cancel(); - if let Some(cancel_wait) = self.should_cancel { - tokio::time::sleep(cancel_wait).await; - - cancellation_token.cancel(); - } } tracing::debug!("ran components");