use std::{sync::Arc, time::Duration};
use anyhow::Result;
use async_broadcast::{Receiver, Sender};
use async_lock::RwLock;
use async_trait::async_trait;
use futures::future::select_all;
use hotshot::{
traits::TestableNodeImplementation,
types::{Event, Message},
};
use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState};
use hotshot_types::{
message::UpgradeLock,
traits::{
network::ConnectedNetwork,
node_implementation::{NodeType, Versions},
},
};
use tokio::task::JoinHandle;
use tokio::{
spawn,
time::{sleep, timeout},
};
use tracing::error;
use crate::test_runner::Node;
/// Outcome reported by a test task's final check.
///
/// Derives `Debug` so results can be logged or asserted on directly; the
/// `Fail` payload is already required to be `Debug`.
#[derive(Debug)]
pub enum TestResult {
    /// The task's checks passed.
    Pass,
    /// The task failed; the boxed payload describes why.
    Fail(Box<dyn std::fmt::Debug + Send + Sync>),
}
/// State owned by a [`TestTask`]: it consumes events one at a time and, when
/// the task is shut down, produces a final verdict.
#[async_trait]
pub trait TestTaskState: Send {
    /// Type of the events this state consumes.
    type Event: Clone + Send + Sync;

    /// Handles a single event.
    ///
    /// `event.0` is the event itself; `event.1` is an id supplied by the
    /// driver (in `TestTask::run` it is the index of the receiver the event
    /// came from). Returned errors are logged by the driver.
    async fn handle_event(&mut self, event: (Self::Event, usize)) -> Result<()>;

    /// Computes the final result; called by the driver on shutdown.
    async fn check(&self) -> TestResult;
}
/// Type-erased [`TestTaskState`] whose events are HotShot `Event<TYPES>`s,
/// allowing heterogeneous task states to be stored together.
pub type AnyTestTaskState<TYPES> = Box<
    dyn TestTaskState<Event = hotshot_types::event::Event<TYPES>> + Send + Sync,
>;
/// Lets a boxed, type-erased state be used anywhere a [`TestTaskState`] is
/// required by forwarding every call to the boxed value.
#[async_trait]
impl<TYPES: NodeType> TestTaskState for AnyTestTaskState<TYPES> {
    type Event = Event<TYPES>;

    async fn handle_event(&mut self, event: (Self::Event, usize)) -> Result<()> {
        // Reborrow the trait object itself so the call dispatches to the inner
        // state rather than recursing into this Box impl.
        let inner = &mut **self;
        inner.handle_event(event).await
    }

    async fn check(&self) -> TestResult {
        let inner = &**self;
        inner.check().await
    }
}
/// Factory that turns itself into an [`AnyTestTaskState`] once the nodes under
/// test are available.
///
/// The `self: Box<Self>` receiver consumes the seed and makes `into_state`
/// callable on a `Box<dyn TestTaskStateSeed<..>>`.
#[async_trait]
pub trait TestTaskStateSeed<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions>:
    Send
{
    /// Builds the task state, given shared, lock-protected access to the nodes.
    async fn into_state(
        self: Box<Self>,
        handles: Arc<RwLock<Vec<Node<TYPES, I, V>>>>,
    ) -> AnyTestTaskState<TYPES>;
}
/// Driver that feeds events from a set of receivers into a [`TestTaskState`]
/// until a [`TestEvent::Shutdown`] is received.
pub struct TestTask<S: TestTaskState> {
    /// State that handles events and produces the final [`TestResult`].
    state: S,
    /// Event sources; an event's source index is passed to the state as its id.
    receivers: Vec<Receiver<<S as TestTaskState>::Event>>,
    /// Control channel carrying shutdown requests.
    test_receiver: Receiver<TestEvent>,
}
/// Control messages delivered to a running [`TestTask`].
#[derive(Debug, Clone)]
pub enum TestEvent {
    /// Stop processing events and return the state's final check result.
    Shutdown,
}
impl<S: TestTaskState + Send + 'static> TestTask<S> {
pub fn new(
state: S,
receivers: Vec<Receiver<S::Event>>,
test_receiver: Receiver<TestEvent>,
) -> Self {
TestTask {
state,
receivers,
test_receiver,
}
}
pub fn run(mut self) -> JoinHandle<TestResult> {
spawn(async move {
loop {
if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() {
break self.state.check().await;
}
let mut messages = Vec::new();
for receiver in &mut self.receivers {
messages.push(receiver.recv());
}
match timeout(Duration::from_millis(2500), select_all(messages)).await {
Ok((Ok(input), id, _)) => {
let _ = S::handle_event(&mut self.state, (input, id))
.await
.inspect_err(|e| tracing::error!("{e}"));
}
Ok((Err(e), _id, _)) => {
error!("Error from one channel in test task {:?}", e);
sleep(Duration::from_millis(4000)).await;
}
_ => {}
};
}
})
}
}
/// Spawns a task that pumps messages from `channel` into HotShot's event
/// streams.
///
/// Each raw message received from the network is deserialized with
/// `upgrade_lock`. The resulting `Message` is handed to a
/// `NetworkMessageTaskState` built from the two event senders and
/// `public_key`.
///
/// Receive and deserialization failures are logged and the message is
/// skipped. The loop never exits on its own; the returned handle can be used
/// to abort it.
pub async fn add_network_message_test_task<
    TYPES: NodeType,
    V: Versions,
    NET: ConnectedNetwork<TYPES::SignatureKey>,
>(
    internal_event_stream: Sender<Arc<HotShotEvent<TYPES>>>,
    external_event_stream: Sender<Event<TYPES>>,
    upgrade_lock: UpgradeLock<TYPES, V>,
    channel: Arc<NET>,
    public_key: TYPES::SignatureKey,
) -> JoinHandle<()> {
    // The senders and key are moved straight into the task state; nothing
    // else in this function needs them, so no clones are required.
    let mut state: NetworkMessageTaskState<_> = NetworkMessageTaskState {
        internal_event_stream,
        external_event_stream,
        public_key,
    };

    spawn(async move {
        loop {
            // NOTE(review): on a persistent receive error this loop retries
            // immediately and can spin; consider a backoff if that occurs.
            let message = match channel.recv_message().await {
                Ok(message) => message,
                Err(e) => {
                    error!("Failed to receive message: {:?}", e);
                    continue;
                }
            };

            let deserialized_message: Message<TYPES> =
                match upgrade_lock.deserialize(&message).await {
                    Ok(message) => message,
                    Err(e) => {
                        tracing::error!("Failed to deserialize message: {:?}", e);
                        continue;
                    }
                };

            state.handle_message(deserialized_message).await;
        }
    })
}