diff --git a/src/lib.rs b/src/lib.rs index 8e83cf8..5341f4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,10 +4,7 @@ use std::{fmt::Display, sync::Arc}; use anyhow::anyhow; use axum::Router; -use futures::{ - FutureExt, - future::{BoxFuture, join_all}, -}; +use futures::{FutureExt, future::BoxFuture}; // State extraction utilities @@ -86,11 +83,17 @@ where router = plugin.on_setup(router, &state)?; } - let shutdown_fns = self + let shutdown_fns: Vec<_> = self .plugins .into_iter() - .filter_map(|mut p| p.on_shutdown(&state)); - let on_shutdown = join_all(shutdown_fns); + .rev() + .filter_map(|mut p| p.on_shutdown(&state)) + .collect(); + let on_shutdown = async move { + for shutdown_fn in shutdown_fns { + shutdown_fn.await; + } + }; Ok((router, state, on_shutdown)) } @@ -187,7 +190,14 @@ impl AppPlugin for AdHocPlugin { #[cfg(test)] mod tests { - use std::convert::Infallible; + use std::{ + convert::Infallible, + sync::{ + Mutex, + atomic::{AtomicUsize, Ordering}, + }, + task::Poll, + }; use super::*; @@ -227,4 +237,76 @@ mod tests { assert_eq!(state.value.as_str(), "ready"); } + + struct ShutdownOrderPlugin { + name: &'static str, + events: Arc>>, + active_shutdowns: Arc, + } + + impl AppPlugin for ShutdownOrderPlugin { + fn on_shutdown(&mut self, _state: &TestState) -> Option> { + let name = self.name; + let events = Arc::clone(&self.events); + let active_shutdowns = Arc::clone(&self.active_shutdowns); + let mut yielded = false; + + Some(Box::pin(futures::future::poll_fn(move |cx| { + if !yielded { + yielded = true; + let previously_active = active_shutdowns.fetch_add(1, Ordering::SeqCst); + assert_eq!(previously_active, 0, "shutdown hooks ran concurrently"); + events + .lock() + .expect("events lock poisoned") + .push(format!("{name}:start")); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + events + .lock() + .expect("events lock poisoned") + .push(format!("{name}:finish")); + active_shutdowns.fetch_sub(1, Ordering::SeqCst); + Poll::Ready(()) + }))) + } + } + + #[test] + fn shutdown_hooks_order() { + let events = Arc::new(Mutex::new(Vec::new())); + let active_shutdowns = Arc::new(AtomicUsize::new(0)); + + let app = App::::new() + .register(AdHocPlugin::::new().on_init(async |mut state| { + state.insert(Arc::new(String::from("ready"))); + Ok(state) + })) + .register(ShutdownOrderPlugin { + name: "first", + events: Arc::clone(&events), + active_shutdowns: Arc::clone(&active_shutdowns), + }) + .register(ShutdownOrderPlugin { + name: "second", + events: Arc::clone(&events), + active_shutdowns, + }); + + let (_router, _state, on_shutdown) = + futures::executor::block_on(app.init()).expect("app should initialize"); + futures::executor::block_on(on_shutdown); + + assert_eq!( + *events.lock().expect("events lock poisoned"), + [ + "second:start", + "second:finish", + "first:start", + "first:finish" + ] + ); + } }