feat: add add_fn to execute immediate lambdas
All checks were successful
continuous-integration/drone/push Build is passing

Signed-off-by: kjuulh <contact@kjuulh.io>
This commit is contained in:
Kasper Juul Hermansen 2024-08-07 15:53:00 +02:00
parent ffa3efd99c
commit 10e2739b6e
Signed by: kjuulh
GPG Key ID: D85D7535F18F35FA
7 changed files with 271 additions and 17 deletions

3
Cargo.lock generated
View File

@ -236,7 +236,7 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]] [[package]]
name = "mad" name = "mad"
version = "0.1.0" version = "0.2.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
@ -247,6 +247,7 @@ dependencies = [
"tokio", "tokio",
"tokio-util", "tokio-util",
"tracing", "tracing",
"tracing-subscriber",
"tracing-test", "tracing-test",
] ]

View File

@ -15,4 +15,5 @@ tokio-util = "0.7.11"
tracing.workspace = true tracing.workspace = true
[dev-dependencies] [dev-dependencies]
tracing-subscriber = "0.3.18"
tracing-test = { version = "0.2.5", features = ["no-env-filter"] } tracing-test = { version = "0.2.5", features = ["no-env-filter"] }

View File

@ -0,0 +1,40 @@
use async_trait::async_trait;
use rand::Rng;
use tokio_util::sync::CancellationToken;
use tracing::Level;
struct WaitServer {}
#[async_trait]
impl mad::Component for WaitServer {
fn name(&self) -> Option<String> {
Some("WaitServer".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait);
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;
Ok(())
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.init();
mad::Mad::builder()
.add(WaitServer {})
.add(WaitServer {})
.add(WaitServer {})
.add(WaitServer {})
.run()
.await?;
Ok(())
}

View File

@ -0,0 +1,42 @@
use async_trait::async_trait;
use rand::Rng;
use tokio_util::sync::CancellationToken;
use tracing::Level;
struct ErrorServer {}
#[async_trait]
impl mad::Component for ErrorServer {
fn name(&self) -> Option<String> {
Some("ErrorServer".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait);
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;
Err(mad::MadError::Inner(anyhow::anyhow!("expected error")))
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.init();
// Do note that only the first server which returns an error is guaranteed to be handled. This is because if servers don't respect cancellation, they will be dropped
mad::Mad::builder()
.add(ErrorServer {})
.add(ErrorServer {})
.add(ErrorServer {})
.add(ErrorServer {})
.run()
.await?;
Ok(())
}

View File

@ -0,0 +1,47 @@
use async_trait::async_trait;
use rand::Rng;
use tokio_util::sync::CancellationToken;
use tracing::Level;
struct WaitServer {}
#[async_trait]
impl mad::Component for WaitServer {
fn name(&self) -> Option<String> {
Some("WaitServer".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait);
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;
Ok(())
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.init();
mad::Mad::builder()
.add(WaitServer {})
.add_fn(|cancel| async move {
let millis_wait = 50;
tracing::debug!("waiting: {}ms", millis_wait);
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;
Ok(())
})
.run()
.await?;
Ok(())
}

View File

@ -0,0 +1,69 @@
use async_trait::async_trait;
use rand::Rng;
use tokio_util::sync::CancellationToken;
use tracing::Level;
struct WaitServer {}
#[async_trait]
impl mad::Component for WaitServer {
fn name(&self) -> Option<String> {
Some("WaitServer".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
let millis_wait = rand::thread_rng().gen_range(500..3000);
tracing::debug!("waiting: {}ms", millis_wait);
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(millis_wait)).await;
Ok(())
}
}
struct RespectCancel {}
#[async_trait]
impl mad::Component for RespectCancel {
fn name(&self) -> Option<String> {
Some("RespectCancel".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
cancellation.cancelled().await;
tracing::debug!("stopping because job is cancelled");
Ok(())
}
}
struct NeverStopServer {}
#[async_trait]
impl mad::Component for NeverStopServer {
fn name(&self) -> Option<String> {
Some("NeverStopServer".into())
}
async fn run(&self, cancellation: CancellationToken) -> Result<(), mad::MadError> {
// Simulates a server running for some time. Is normally supposed to be futures blocking indefinitely
tokio::time::sleep(std::time::Duration::from_millis(999999999)).await;
Ok(())
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.init();
mad::Mad::builder()
.add(WaitServer {})
.add(NeverStopServer {})
.add(RespectCancel {})
.run()
.await?;
Ok(())
}

View File

@ -49,6 +49,11 @@ pub struct Mad {
should_cancel: Option<std::time::Duration>, should_cancel: Option<std::time::Duration>,
} }
struct CompletionResult {
res: Result<(), MadError>,
name: Option<String>,
}
impl Mad { impl Mad {
pub fn builder() -> Self { pub fn builder() -> Self {
Self { Self {
@ -64,6 +69,16 @@ impl Mad {
self self
} }
pub fn add_fn<F, Fut>(&mut self, f: F) -> &mut Self
where
F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
Fut: futures::Future<Output = Result<(), MadError>> + Send + 'static,
{
let comp = ClosureComponent { inner: Box::new(f) };
self.add(comp)
}
pub fn cancellation(&mut self, should_cancel: Option<std::time::Duration>) -> &mut Self { pub fn cancellation(&mut self, should_cancel: Option<std::time::Duration>) -> &mut Self {
self.should_cancel = should_cancel; self.should_cancel = should_cancel;
@ -122,22 +137,23 @@ impl Mad {
let cancellation_token = cancellation_token.child_token(); let cancellation_token = cancellation_token.child_token();
let job_cancellation = job_cancellation.child_token(); let job_cancellation = job_cancellation.child_token();
let (error_tx, error_rx) = tokio::sync::mpsc::channel::<Result<(), MadError>>(1); let (error_tx, error_rx) = tokio::sync::mpsc::channel::<CompletionResult>(1);
channels.push(error_rx); channels.push(error_rx);
tokio::spawn(async move { tokio::spawn(async move {
tracing::debug!(component = &comp.name(), "mad running"); let name = comp.name().clone();
tracing::debug!(component = name, "mad running");
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => { _ = cancellation_token.cancelled() => {
error_tx.send(Ok(())).await error_tx.send(CompletionResult { res: Ok(()) , name }).await
} }
res = comp.run(job_cancellation) => { res = comp.run(job_cancellation) => {
error_tx.send(res).await error_tx.send(CompletionResult { res , name }).await
} }
_ = tokio::signal::ctrl_c() => { _ = tokio::signal::ctrl_c() => {
error_tx.send(Ok(())).await error_tx.send(CompletionResult { res: Ok(()) , name }).await
} }
} }
}); });
@ -149,19 +165,26 @@ impl Mad {
} }
while let Some(Some(msg)) = futures.next().await { while let Some(Some(msg)) = futures.next().await {
tracing::trace!("received end signal from a component"); match msg.res {
Err(e) => {
tracing::debug!(
error = e.to_string(),
component = msg.name,
"component ran to completion with error"
);
}
Ok(_) => {
tracing::debug!(component = msg.name, "component ran to completion");
}
}
if let Err(e) = msg {
tracing::debug!(error = e.to_string(), "stopping running components");
job_cancellation.cancel(); job_cancellation.cancel();
if let Some(cancel_wait) = self.should_cancel { if let Some(cancel_wait) = self.should_cancel {
tokio::time::sleep(cancel_wait).await; tokio::time::sleep(cancel_wait).await;
cancellation_token.cancel(); cancellation_token.cancel();
} }
} }
}
tracing::debug!("ran components"); tracing::debug!("ran components");
@ -211,3 +234,34 @@ impl<T: Component + Send + Sync + 'static> IntoComponent for T {
Arc::new(self) Arc::new(self)
} }
} }
struct ClosureComponent<F, Fut>
where
F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
Fut: futures::Future<Output = Result<(), MadError>> + Send + 'static,
{
inner: Box<F>,
}
impl<F, Fut> ClosureComponent<F, Fut>
where
F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
Fut: futures::Future<Output = Result<(), MadError>> + Send + 'static,
{
pub async fn execute(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
(*self.inner)(cancellation_token).await?;
Ok(())
}
}
#[async_trait::async_trait]
impl<F, Fut> Component for ClosureComponent<F, Fut>
where
F: Fn(CancellationToken) -> Fut + Send + Sync + 'static,
Fut: futures::Future<Output = Result<(), MadError>> + Send + 'static,
{
async fn run(&self, cancellation_token: CancellationToken) -> Result<(), MadError> {
self.execute(cancellation_token).await
}
}