1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.

use std::sync::Arc;

use anyhow::Result;
use async_broadcast::{Receiver, Sender};
#[cfg(async_executor_impl = "async-std")]
use async_std::task::{spawn, JoinHandle};
use async_trait::async_trait;
#[cfg(async_executor_impl = "async-std")]
use futures::future::join_all;
#[cfg(async_executor_impl = "tokio")]
use futures::future::try_join_all;
#[cfg(async_executor_impl = "tokio")]
use tokio::task::{spawn, JoinHandle};

/// Trait for events that long-running tasks handle
///
/// The `PartialEq` bound is what lets the task loop compare each received event
/// against [`TaskEvent::shutdown_event`] to decide when to stop.
pub trait TaskEvent: PartialEq {
    /// The shutdown signal for this event type
    ///
    /// Note that this is necessarily uniform across all tasks.
    /// Exiting the task loop is handled by the task spawner, rather than the task individually:
    /// when a running `Task` receives an event equal to this value, it awaits
    /// `TaskState::cancel_subtasks` and then exits, returning its state.
    fn shutdown_event() -> Self;
}

#[async_trait]
/// Type for mutable task state that can be used as the state for a `Task`
///
/// A `Task` owns one value implementing this trait and passes it every event it
/// receives. When the task loop exits, the value is returned (boxed as a
/// `dyn TaskState`) through the task's `JoinHandle`.
pub trait TaskState: Send {
    /// Type of event sent and received by the task
    type Event: TaskEvent + Clone + Send + Sync;

    /// Joins all subtasks.
    ///
    /// Awaited by `Task::run` when the shutdown event arrives, and awaited again by
    /// `ConsensusTaskRegistry::shutdown` on each returned state. Implementations
    /// should therefore tolerate being called more than once.
    async fn cancel_subtasks(&mut self);

    /// Handles an event, providing direct access to the specific channel we received the event on.
    ///
    /// `Task::run` calls this for every received event other than the shutdown event.
    /// An `Err` return is logged at `info` level by the task loop and does not stop the task.
    async fn handle_event(
        &mut self,
        event: Arc<Self::Event>,
        _sender: &Sender<Arc<Self::Event>>,
        _receiver: &Receiver<Arc<Self::Event>>,
    ) -> Result<()>;
}

/// A basic task which loops waiting for events to come from `receiver`
/// and then handles them using its state.
/// It sends events to other `Task`s through `sender`.
/// This should be used as the primary building block for long running
/// or medium running tasks (i.e. anything that can't be described as a dependency task)
pub struct Task<S: TaskState> {
    /// The state of the task.  It is fed events from `receiver`
    /// and mutated via `handle_event`.
    state: S,
    /// Sends events to all tasks, including itself
    sender: Sender<Arc<S::Event>>,
    /// Receives events that are broadcast from any task, including itself
    receiver: Receiver<Arc<S::Event>>,
}

impl<S: TaskState + Send + 'static> Task<S> {
    /// Create a new task
    pub fn new(state: S, sender: Sender<Arc<S::Event>>, receiver: Receiver<Arc<S::Event>>) -> Self {
        Task {
            state,
            sender,
            receiver,
        }
    }

    /// The state of the task, as a boxed dynamic trait object.
    fn boxed_state(self) -> Box<dyn TaskState<Event = S::Event>> {
        Box::new(self.state) as Box<dyn TaskState<Event = S::Event>>
    }

    /// Spawn the task loop, consuming self.  Will continue until
    /// the task reaches some shutdown condition
    pub fn run(mut self) -> JoinHandle<Box<dyn TaskState<Event = S::Event>>> {
        spawn(async move {
            loop {
                match self.receiver.recv_direct().await {
                    Ok(input) => {
                        if *input == S::Event::shutdown_event() {
                            self.state.cancel_subtasks().await;

                            break self.boxed_state();
                        }

                        let _ =
                            S::handle_event(&mut self.state, input, &self.sender, &self.receiver)
                                .await
                                .inspect_err(|e| tracing::info!("{e}"));
                    }
                    Err(e) => {
                        tracing::error!("Failed to receive from event stream Error: {}", e);
                    }
                }
            }
        })
    }
}

/// A collection of tasks which can handle shutdown
///
/// Holds the `JoinHandle`s of spawned task loops. Each handle resolves to the
/// task's boxed state once its loop exits.
pub struct ConsensusTaskRegistry<EVENT> {
    /// Tasks this registry controls
    task_handles: Vec<JoinHandle<Box<dyn TaskState<Event = EVENT>>>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would add an
// unnecessary `EVENT: Default` bound, making `default()` unavailable for any
// event type that does not itself implement `Default`.
impl<EVENT> Default for ConsensusTaskRegistry<EVENT> {
    fn default() -> Self {
        Self {
            task_handles: Vec::new(),
        }
    }
}

impl<EVENT: Send + Sync + Clone + TaskEvent> ConsensusTaskRegistry<EVENT> {
    #[must_use]
    /// Create a new task registry
    pub fn new() -> Self {
        ConsensusTaskRegistry {
            task_handles: vec![],
        }
    }
    /// Add a task to the registry
    pub fn register(&mut self, handle: JoinHandle<Box<dyn TaskState<Event = EVENT>>>) {
        self.task_handles.push(handle);
    }
    /// Try to cancel/abort the task this registry has
    ///
    /// # Panics
    ///
    /// Should not panic, unless awaiting on the JoinHandle in tokio fails.
    pub async fn shutdown(&mut self) {
        let handles = &mut self.task_handles;

        while let Some(handle) = handles.pop() {
            #[cfg(async_executor_impl = "async-std")]
            let mut task_state = handle.await;
            #[cfg(async_executor_impl = "tokio")]
            let mut task_state = handle.await.unwrap();

            task_state.cancel_subtasks().await;
        }
    }
    /// Take a task, run it, and register it
    pub fn run_task<S>(&mut self, task: Task<S>)
    where
        S: TaskState<Event = EVENT> + Send + 'static,
    {
        self.register(task.run());
    }

    /// Wait for the results of all the tasks registered
    /// # Panics
    /// Panics if one of the tasks panicked
    pub async fn join_all(self) -> Vec<Box<dyn TaskState<Event = EVENT>>> {
        #[cfg(async_executor_impl = "async-std")]
        let states = join_all(self.task_handles).await;
        #[cfg(async_executor_impl = "tokio")]
        let states = try_join_all(self.task_handles).await.unwrap();

        states
    }
}

#[derive(Default)]
/// A collection of tasks which can handle shutdown
///
/// Holds the handles of spawned tasks that produce `()`; `shutdown` awaits all of them.
pub struct NetworkTaskRegistry {
    /// Tasks this registry controls
    pub handles: Vec<JoinHandle<()>>,
}

impl NetworkTaskRegistry {
    #[must_use]
    /// Create a new task registry
    pub fn new() -> Self {
        NetworkTaskRegistry { handles: vec![] }
    }

    #[allow(clippy::unused_async)]
    /// Shuts down all tasks managed by this instance.
    ///
    /// This function waits for all tasks to complete before returning.
    ///
    /// # Panics
    ///
    /// When using the tokio executor, this function will panic if any of the
    /// tasks being joined return an error.
    pub async fn shutdown(&mut self) {
        let handles = std::mem::take(&mut self.handles);
        #[cfg(async_executor_impl = "async-std")]
        join_all(handles).await;
        #[cfg(async_executor_impl = "tokio")]
        try_join_all(handles)
            .await
            .expect("Failed to join all tasks during shutdown");
    }

    /// Add a task to the registry
    pub fn register(&mut self, handle: JoinHandle<()>) {
        self.handles.push(handle);
    }
}