#![allow(clippy::panic)]
use std::{
collections::{BTreeMap, HashMap, HashSet},
marker::PhantomData,
sync::Arc,
};
use async_broadcast::{broadcast, Receiver, Sender};
use async_lock::RwLock;
use futures::future::join_all;
use hotshot::{
traits::TestableNodeImplementation,
types::{Event, SystemContextHandle},
HotShotInitializer, MarketplaceConfig, SystemContext,
};
use hotshot_example_types::{
auction_results_provider_types::TestAuctionResultsProvider,
block_types::TestBlockHeader,
state_types::{TestInstanceState, TestValidatedState},
storage_types::TestStorage,
};
use hotshot_fakeapi::fake_solver::FakeSolverState;
use hotshot_task_impls::events::HotShotEvent;
use hotshot_types::{
consensus::ConsensusMetricsValue,
constants::EVENT_CHANNEL_SIZE,
data::Leaf2,
simple_certificate::QuorumCertificate2,
traits::{
election::Membership,
network::ConnectedNetwork,
node_implementation::{ConsensusTime, NodeImplementation, NodeType, Versions},
},
HotShotConfig, ValidatorConfig,
};
use tide_disco::Url;
use tokio::{spawn, task::JoinHandle};
#[allow(deprecated)]
use tracing::info;
use super::{
completion_task::CompletionTask,
consistency_task::ConsistencyTask,
overall_safety_task::{OverallSafetyTask, RoundCtx},
txn_task::TxnTask,
};
use crate::{
block_builder::{BuilderTask, TestBuilderImplementation},
completion_task::CompletionTaskDescription,
spinning_task::{ChangeNode, NodeAction, SpinningTask},
test_builder::create_test_handle,
test_launcher::{Network, TestLauncher},
test_task::{TestResult, TestTask},
txn_task::TxnTaskDescription,
view_sync_task::ViewSyncTask,
};
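
/// Convenience trait for errors returned by test tasks: they must be `Send + Sync + 'static`.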
pub trait TaskErr: std::error::Error + Sync + Send + 'static {}
impl<T: std::error::Error + Sync + Send + 'static> TaskErr for T {}
impl<
TYPES: NodeType<
InstanceState = TestInstanceState,
ValidatedState = TestValidatedState,
BlockHeader = TestBlockHeader,
>,
I: TestableNodeImplementation<TYPES>,
V: Versions,
N: ConnectedNetwork<TYPES::SignatureKey>,
> TestRunner<TYPES, I, V, N>
where
I: NodeImplementation<
TYPES,
Network = N,
Storage = TestStorage<TYPES>,
AuctionResultsProvider = TestAuctionResultsProvider<TYPES>,
>,
{
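    /// Run the test: launch the nodes and test tasks, wait for the completion
    /// task to end the test, then shut everything down and panic if any task
    /// reported a failure.
    ///
    /// A minimal usage sketch, assuming the `TestDescription` builder API from
    /// `crate::test_builder` and the `SimpleBuilderImplementation` builder
    /// (marked `ignore` since it requires a full test environment):
    ///
    /// ```ignore
    /// let runner = TestDescription::default()
    ///     .gen_launcher(0) // node id used for config generation (assumed signature)
    ///     .launch();
    /// runner.run_test::<SimpleBuilderImplementation>().await;
    /// ```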
#[allow(clippy::too_many_lines)]
pub async fn run_test<B: TestBuilderImplementation<TYPES>>(mut self) {
let (test_sender, test_receiver) = broadcast(EVENT_CHANNEL_SIZE);
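        // Collect which nodes start late and which restart, per the spinning-task description.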
let spinning_changes = self
.launcher
.metadata
.spinning_properties
.node_changes
.clone();
let mut late_start_nodes: HashSet<u64> = HashSet::new();
let mut restart_nodes: HashSet<u64> = HashSet::new();
for (_, changes) in &spinning_changes {
for change in changes {
if matches!(change.updown, NodeAction::Up) {
late_start_nodes.insert(change.idx.try_into().unwrap());
}
if matches!(change.updown, NodeAction::RestartDown(_)) {
restart_nodes.insert(change.idx.try_into().unwrap());
}
}
}
self.add_nodes::<B>(
self.launcher.metadata.num_nodes_with_stake,
&late_start_nodes,
&restart_nodes,
)
.await;
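        // Collect a receiver for each node's external and internal event streams,
        // so the test tasks can observe every node.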
let mut event_rxs = vec![];
let mut internal_event_rxs = vec![];
for node in &self.nodes {
let r = node.handle.event_stream_known_impl();
event_rxs.push(r);
}
for node in &self.nodes {
let r = node.handle.internal_event_stream_receiver_known_impl();
internal_event_rxs.push(r);
}
let TestRunner {
launcher,
nodes,
solver_server,
late_start,
next_node_id: _,
_pd: _,
} = self;
let mut task_futs = vec![];
let meta = launcher.metadata.clone();
let handles = Arc::new(RwLock::new(nodes));
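        // Add the transaction task, if the test submits transactions round-robin on a timer.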
let txn_task =
if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description {
let txn_task = TxnTask {
handles: Arc::clone(&handles),
next_node_idx: Some(0),
duration,
shutdown_chan: test_receiver.clone(),
};
Some(txn_task)
} else {
None
};
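        // The completion task ends the test after the configured duration.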
let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) =
meta.completion_task_description;
let completion_task = CompletionTask {
tx: test_sender.clone(),
rx: test_receiver.clone(),
duration: time_based.duration,
};
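        // Index the spinning changes by the view in which they take effect.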
let mut changes: BTreeMap<TYPES::View, Vec<ChangeNode>> = BTreeMap::new();
for (view, mut change) in spinning_changes {
changes
.entry(TYPES::View::new(view))
            .or_default()
.append(&mut change);
}
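        // The spinning task brings nodes up, down, and back up on schedule.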
let spinning_task_state = SpinningTask {
handles: Arc::clone(&handles),
late_start,
latest_view: None,
changes,
last_decided_leaf: Leaf2::genesis(
&TestValidatedState::default(),
&TestInstanceState::default(),
)
.await,
high_qc: QuorumCertificate2::genesis::<V>(
&TestValidatedState::default(),
&TestInstanceState::default(),
)
.await,
next_epoch_high_qc: None,
async_delay_config: launcher.metadata.async_delay_config,
restart_contexts: HashMap::new(),
channel_generator: launcher.resource_generator.channel_generator,
};
let spinning_task = TestTask::<SpinningTask<TYPES, N, I, V>>::new(
spinning_task_state,
event_rxs.clone(),
test_receiver.clone(),
);
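        // The overall-safety and consistency tasks validate the decided leaves across nodes.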
let overall_safety_task_state = OverallSafetyTask {
handles: Arc::clone(&handles),
ctx: RoundCtx::default(),
properties: launcher.metadata.overall_safety_properties.clone(),
error: None,
test_sender,
};
let consistency_task_state = ConsistencyTask {
consensus_leaves: BTreeMap::new(),
safety_properties: launcher.metadata.overall_safety_properties,
ensure_upgrade: launcher.metadata.upgrade_view.is_some(),
validate_transactions: launcher.metadata.validate_transactions,
_pd: PhantomData,
};
let consistency_task = TestTask::<ConsistencyTask<TYPES, V>>::new(
consistency_task_state,
event_rxs.clone(),
test_receiver.clone(),
);
let overall_safety_task = TestTask::<OverallSafetyTask<TYPES, I, V>>::new(
overall_safety_task_state,
event_rxs.clone(),
test_receiver.clone(),
);
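        // The view-sync task records which nodes enter view sync.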
let view_sync_task_state = ViewSyncTask {
hit_view_sync: HashSet::new(),
description: launcher.metadata.view_sync_properties,
_pd: PhantomData,
};
let view_sync_task = TestTask::<ViewSyncTask<TYPES, I>>::new(
view_sync_task_state,
internal_event_rxs,
test_receiver.clone(),
);
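        // Wait for every network to be ready, then start consensus on all nodes
        // that are not starting late.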
let nodes = handles.read().await;
for node in &*nodes {
node.network.wait_for_ready().await;
}
for node in &*nodes {
if !late_start_nodes.contains(&node.node_id) {
node.handle.hotshot.start_consensus().await;
}
}
drop(nodes);
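        // Launch any additional, user-supplied test tasks.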
for seed in launcher.additional_test_tasks {
let task = TestTask::new(
seed.into_state(Arc::clone(&handles)).await,
event_rxs.clone(),
test_receiver.clone(),
);
task_futs.push(task.run());
}
task_futs.push(overall_safety_task.run());
task_futs.push(consistency_task.run());
task_futs.push(view_sync_task.run());
task_futs.push(spinning_task.run());
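        // The transaction and completion tasks are not joined; they are aborted once
        // the main tasks finish.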
let txn_handle = txn_task.map(|txn| txn.run());
let completion_handle = completion_task.run();
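        // Join the main test tasks and collect any reported failures.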
let mut error_list = vec![];
let results = join_all(task_futs).await;
for result in results {
match result {
Ok(res) => match res {
TestResult::Pass => {
info!("Task shut down successfully");
}
TestResult::Fail(e) => error_list.push(e),
},
Err(e) => {
tracing::error!("Error Joining the test task {:?}", e);
}
}
}
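        // The test is over: abort the auxiliary tasks and shut down every node.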
if let Some(handle) = txn_handle {
handle.abort();
}
if let Some(solver_server) = solver_server {
solver_server.1.abort();
}
let mut nodes = handles.write().await;
for node in &mut *nodes {
node.handle.shut_down().await;
}
tracing::info!("Nodes shutdown");
completion_handle.abort();
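        // Panic with all collected errors, if any task failed.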
assert!(
error_list.is_empty(),
"{}",
error_list
.iter()
.fold("TEST FAILED! Results:".to_string(), |acc, error| {
format!("{acc}\n\n{error:?}")
})
);
}
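
    /// Start the builders configured in the test metadata, plus a fallback builder.
    /// Returns the builder tasks, the builder URLs, and the fallback builder's URL.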
pub async fn init_builders<B: TestBuilderImplementation<TYPES>>(
&self,
num_nodes: usize,
) -> (Vec<Box<dyn BuilderTask<TYPES>>>, Vec<Url>, Url) {
let config = self.launcher.resource_generator.config.clone();
let mut builder_tasks = Vec::new();
let mut builder_urls = Vec::new();
for metadata in &self.launcher.metadata.builders {
let builder_port = portpicker::pick_unused_port().expect("No free ports");
let builder_url =
Url::parse(&format!("http://localhost:{builder_port}")).expect("Invalid URL");
let builder_task = B::start(
num_nodes,
builder_url.clone(),
B::Config::default(),
metadata.changes.clone(),
)
.await;
builder_tasks.push(builder_task);
builder_urls.push(builder_url);
}
let fallback_builder_port = portpicker::pick_unused_port().expect("No free ports");
let fallback_builder_url =
Url::parse(&format!("http://localhost:{fallback_builder_port}")).expect("Invalid URL");
let fallback_builder_task = B::start(
config.num_nodes_with_stake.into(),
fallback_builder_url.clone(),
B::Config::default(),
self.launcher.metadata.fallback_builder.changes.clone(),
)
.await;
builder_tasks.push(fallback_builder_task);
(builder_tasks, builder_urls, fallback_builder_url)
}
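
    /// Start the fake solver server on a free port, pointing it at the given
    /// builder URLs, and store its URL and join handle on the runner.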
pub async fn add_solver(&mut self, builder_urls: Vec<Url>) {
let solver_error_pct = self.launcher.metadata.solver.error_pct;
let solver_port = portpicker::pick_unused_port().expect("No available ports");
let solver_url: Url = format!("http://localhost:{solver_port}")
.parse()
.expect("Failed to parse solver URL");
let solver_state = FakeSolverState::new(Some(solver_error_pct), builder_urls);
self.solver_server = Some((
solver_url.clone(),
spawn(async move {
solver_state
.run::<TYPES>(solver_url)
.await
.expect("Unable to run solver api");
}),
));
}
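
    /// Add `total` nodes to the test network. Nodes whose ids appear in
    /// `late_start` are stashed for the spinning task instead of being launched
    /// immediately; nodes in `restart` are registered with a `Restart` context.
    /// Returns the ids of all nodes added.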
pub async fn add_nodes<B: TestBuilderImplementation<TYPES>>(
&mut self,
total: usize,
late_start: &HashSet<u64>,
restart: &HashSet<u64>,
) -> Vec<u64> {
let mut results = vec![];
let config = self.launcher.resource_generator.config.clone();
let temp_memberships = <TYPES as NodeType>::Membership::new(
config.known_nodes_with_stake.clone(),
config.known_da_nodes.clone(),
);
let num_nodes = temp_memberships.total_nodes(TYPES::Epoch::new(0));
let (mut builder_tasks, builder_urls, fallback_builder_url) =
self.init_builders::<B>(num_nodes).await;
if self.launcher.metadata.start_solver {
self.add_solver(builder_urls.clone()).await;
}
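        // Create each node's network, storage, and marketplace config. Late-start
        // nodes are stashed for later; everything else is queued for initialization.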
let mut uninitialized_nodes = Vec::new();
let mut networks_ready = Vec::new();
for i in 0..total {
let mut config = config.clone();
if let Some(upgrade_view) = self.launcher.metadata.upgrade_view {
config.set_view_upgrade(upgrade_view);
}
let node_id = self.next_node_id;
self.next_node_id += 1;
tracing::debug!("launch node {}", i);
config.builder_urls = builder_urls
.clone()
.try_into()
.expect("Non-empty by construction");
let network = (self.launcher.resource_generator.channel_generator)(node_id).await;
let storage = (self.launcher.resource_generator.storage)(node_id);
let mut marketplace_config =
(self.launcher.resource_generator.marketplace_config)(node_id);
if let Some(solver_server) = &self.solver_server {
let mut new_auction_results_provider =
marketplace_config.auction_results_provider.as_ref().clone();
new_auction_results_provider.broadcast_url = Some(solver_server.0.clone());
marketplace_config.auction_results_provider = new_auction_results_provider.into();
}
marketplace_config.fallback_builder_url = fallback_builder_url.clone();
let network_clone = network.clone();
let networks_ready_future = async move {
network_clone.wait_for_ready().await;
};
networks_ready.push(networks_ready_future);
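            // A late-start node is not launched now. With `skip_late` it is stored
            // uninitialized; otherwise its context is created up front and handed
            // to the spinning task.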
if late_start.contains(&node_id) {
if self.launcher.metadata.skip_late {
self.late_start.insert(
node_id,
LateStartNode {
network: None,
context: LateNodeContext::UninitializedContext(
LateNodeContextParameters {
storage,
memberships: <TYPES as NodeType>::Membership::new(
config.known_nodes_with_stake.clone(),
config.known_da_nodes.clone(),
),
config,
marketplace_config,
},
),
},
);
} else {
let initializer = HotShotInitializer::<TYPES>::from_genesis::<V>(
TestInstanceState::new(self.launcher.metadata.async_delay_config.clone()),
)
.await
.unwrap();
let is_da = node_id < config.da_staked_committee_size as u64;
let validator_config =
ValidatorConfig::generated_from_seed_indexed([0u8; 32], node_id, 1, is_da);
let hotshot = Self::add_node_with_config(
node_id,
network.clone(),
<TYPES as NodeType>::Membership::new(
config.known_nodes_with_stake.clone(),
config.known_da_nodes.clone(),
),
initializer,
config,
validator_config,
storage,
marketplace_config,
)
.await;
self.late_start.insert(
node_id,
LateStartNode {
network: Some(network),
context: LateNodeContext::InitializedContext(hotshot),
},
);
}
} else {
uninitialized_nodes.push((
node_id,
network,
<TYPES as NodeType>::Membership::new(
config.known_nodes_with_stake.clone(),
config.known_da_nodes.clone(),
),
config,
storage,
marketplace_config,
));
}
results.push(node_id);
}
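        // Nodes scheduled to restart get a `Restart` context so the spinning task
        // can relaunch them with state taken from the running node.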
for node_id in &results {
if restart.contains(node_id) {
self.late_start.insert(
*node_id,
LateStartNode {
network: None,
context: LateNodeContext::Restart,
},
);
}
}
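        // Wait for all networks to come up before creating the node handles.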
join_all(networks_ready).await;
for (node_id, network, memberships, config, storage, marketplace_config) in
uninitialized_nodes
{
let handle = create_test_handle(
self.launcher.metadata.clone(),
node_id,
network.clone(),
Arc::new(RwLock::new(memberships)),
config.clone(),
storage,
marketplace_config,
)
.await;
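            // Distribute builder tasks across the DA committee: each DA node except
            // the last starts one task; the last DA node starts all remaining tasks
            // (including the fallback builder).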
match node_id.cmp(&(config.da_staked_committee_size as u64 - 1)) {
std::cmp::Ordering::Less => {
if let Some(task) = builder_tasks.pop() {
task.start(Box::new(handle.event_stream()))
}
}
std::cmp::Ordering::Equal => {
while let Some(task) = builder_tasks.pop() {
task.start(Box::new(handle.event_stream()))
}
}
std::cmp::Ordering::Greater => {}
}
self.nodes.push(Node {
node_id,
network,
handle,
});
}
results
}
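
    /// Create a `SystemContext` for a single node from the given configuration,
    /// deriving its keys from the validator config.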
#[allow(clippy::too_many_arguments)]
pub async fn add_node_with_config(
node_id: u64,
network: Network<TYPES, I>,
memberships: TYPES::Membership,
initializer: HotShotInitializer<TYPES>,
config: HotShotConfig<TYPES::SignatureKey>,
validator_config: ValidatorConfig<TYPES::SignatureKey>,
storage: I::Storage,
marketplace_config: MarketplaceConfig<TYPES, I>,
) -> Arc<SystemContext<TYPES, I, V>> {
let private_key = validator_config.private_key.clone();
let public_key = validator_config.public_key.clone();
SystemContext::new(
public_key,
private_key,
node_id,
config,
Arc::new(RwLock::new(memberships)),
network,
initializer,
ConsensusMetricsValue::default(),
storage,
marketplace_config,
)
.await
}
#[allow(clippy::too_many_arguments, clippy::type_complexity)]
pub async fn add_node_with_config_and_channels(
node_id: u64,
network: Network<TYPES, I>,
memberships: Arc<RwLock<TYPES::Membership>>,
initializer: HotShotInitializer<TYPES>,
config: HotShotConfig<TYPES::SignatureKey>,
validator_config: ValidatorConfig<TYPES::SignatureKey>,
storage: I::Storage,
marketplace_config: MarketplaceConfig<TYPES, I>,
internal_channel: (
Sender<Arc<HotShotEvent<TYPES>>>,
Receiver<Arc<HotShotEvent<TYPES>>>,
),
external_channel: (Sender<Event<TYPES>>, Receiver<Event<TYPES>>),
) -> Arc<SystemContext<TYPES, I, V>> {
let private_key = validator_config.private_key.clone();
let public_key = validator_config.public_key.clone();
SystemContext::new_from_channels(
public_key,
private_key,
node_id,
config,
memberships,
network,
initializer,
ConsensusMetricsValue::default(),
storage,
marketplace_config,
internal_channel,
external_channel,
)
}
}

/// A running node in the test network.
pub struct Node<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The node's id.
    pub node_id: u64,
    /// The underlying network belonging to the node.
    pub network: Network<TYPES, I>,
    /// The handle to the node's internals.
    pub handle: SystemContextHandle<TYPES, I, V>,
}

/// The parameters needed to launch a late-start node that has not been initialized yet.
pub struct LateNodeContextParameters<TYPES: NodeType, I: TestableNodeImplementation<TYPES>> {
    /// The storage for the node.
    pub storage: I::Storage,
    /// The memberships of this particular node.
    pub memberships: TYPES::Membership,
    /// The config associated with this node.
    pub config: HotShotConfig<TYPES::SignatureKey>,
    /// The marketplace config for this node.
    pub marketplace_config: MarketplaceConfig<TYPES, I>,
}

/// The context of a late-start node: either a fully initialized `SystemContext`,
/// the parameters to initialize one, or a marker that the node is restarting.
#[allow(clippy::large_enum_variant)]
pub enum LateNodeContext<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The node is already initialized; its system context is passed directly.
    InitializedContext(Arc<SystemContext<TYPES, I, V>>),
    /// The node is not yet initialized; these parameters will be used to build it.
    UninitializedContext(LateNodeContextParameters<TYPES, I>),
    /// The node is restarting; its context will be taken from the running node when it goes down.
    Restart,
}

/// A yet-to-be-started node that participates in tests.
pub struct LateStartNode<TYPES: NodeType, I: TestableNodeImplementation<TYPES>, V: Versions> {
    /// The network for this node, if it was created ahead of time.
    pub network: Option<Network<TYPES, I>>,
    /// The context with which HotShot will be launched when it is time.
    pub context: LateNodeContext<TYPES, I, V>,
}

/// The runner of a test network: owns the launcher, the running nodes,
/// and the nodes that will start late.
pub struct TestRunner<
    TYPES: NodeType,
    I: TestableNodeImplementation<TYPES>,
    V: Versions,
    N: ConnectedNetwork<TYPES::SignatureKey>,
> {
    /// The test launcher, containing the test metadata and resource generators.
    pub(crate) launcher: TestLauncher<TYPES, I, V>,
    /// The nodes that are currently running.
    pub(crate) nodes: Vec<Node<TYPES, I, V>>,
    /// The URL and join handle of the fake solver server, if one was started.
    pub(crate) solver_server: Option<(Url, JoinHandle<()>)>,
    /// The nodes that will start late, keyed by node id.
    pub(crate) late_start: HashMap<u64, LateStartNode<TYPES, I, V>>,
    /// The id to assign to the next node added.
    pub(crate) next_node_id: u64,
    /// Phantom marker for the network type.
    pub(crate) _pd: PhantomData<N>,
}