use std::{
collections::{hash_map::DefaultHasher, BTreeMap, HashMap},
future::Future,
hash::{Hash, Hasher},
num::NonZeroUsize,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc,
},
time::Duration,
};
use async_broadcast::{broadcast, InactiveReceiver, Sender};
use async_lock::RwLock;
use async_trait::async_trait;
use futures::{join, select, FutureExt};
#[cfg(feature = "hotshot-testing")]
use hotshot_types::traits::network::{
AsyncGenerator, NetworkReliability, TestableNetworkingImplementation,
};
use hotshot_types::{
boxed_sync,
constants::{
COMBINED_NETWORK_CACHE_SIZE, COMBINED_NETWORK_DELAY_DURATION,
COMBINED_NETWORK_MIN_PRIMARY_FAILURES, COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL,
},
data::ViewNumber,
traits::{
network::{BroadcastDelay, ConnectedNetwork, Topic},
node_implementation::NodeType,
},
BoxSyncFuture,
};
use lru::LruCache;
use parking_lot::RwLock as PlRwLock;
use tokio::{spawn, sync::mpsc::error::TrySendError, time::sleep};
use tracing::{debug, info, warn};
use super::{push_cdn_network::PushCdnNetwork, NetworkError};
use crate::traits::implementations::Libp2pNetwork;
pub fn calculate_hash_of<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
t.hash(&mut s);
s.finish()
}
type DelayedTasksChannelsMap = Arc<RwLock<BTreeMap<u64, (Sender<()>, InactiveReceiver<()>)>>>;
#[derive(Clone)]
pub struct CombinedNetworks<TYPES: NodeType> {
networks: Arc<UnderlyingCombinedNetworks<TYPES>>,
message_cache: Arc<PlRwLock<LruCache<u64, ()>>>,
primary_fail_counter: Arc<AtomicU64>,
primary_down: Arc<AtomicBool>,
delay_duration: Arc<RwLock<Duration>>,
delayed_tasks_channels: DelayedTasksChannelsMap,
no_delay_counter: Arc<AtomicU64>,
}
impl<TYPES: NodeType> CombinedNetworks<TYPES> {
#[must_use]
pub fn new(
primary_network: PushCdnNetwork<TYPES::SignatureKey>,
secondary_network: Libp2pNetwork<TYPES>,
delay_duration: Option<Duration>,
) -> Self {
let networks = Arc::from(UnderlyingCombinedNetworks(
primary_network,
secondary_network,
));
Self {
networks,
message_cache: Arc::new(PlRwLock::new(LruCache::new(
NonZeroUsize::new(COMBINED_NETWORK_CACHE_SIZE).unwrap(),
))),
primary_fail_counter: Arc::new(AtomicU64::new(0)),
primary_down: Arc::new(AtomicBool::new(false)),
delay_duration: Arc::new(RwLock::new(
delay_duration.unwrap_or(Duration::from_millis(COMBINED_NETWORK_DELAY_DURATION)),
)),
delayed_tasks_channels: Arc::default(),
no_delay_counter: Arc::new(AtomicU64::new(0)),
}
}
#[must_use]
pub fn primary(&self) -> &PushCdnNetwork<TYPES::SignatureKey> {
&self.networks.0
}
#[must_use]
pub fn secondary(&self) -> &Libp2pNetwork<TYPES> {
&self.networks.1
}
async fn send_both_networks(
&self,
_message: Vec<u8>,
primary_future: impl Future<Output = Result<(), NetworkError>> + Send + 'static,
secondary_future: impl Future<Output = Result<(), NetworkError>> + Send + 'static,
broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
let mut primary_failed = false;
if self.primary_down.load(Ordering::Relaxed) {
primary_failed = true;
} else if self.primary_fail_counter.load(Ordering::Relaxed)
> COMBINED_NETWORK_MIN_PRIMARY_FAILURES
{
info!(
"View progression is slower than normally, stop delaying messages on the secondary"
);
self.primary_down.store(true, Ordering::Relaxed);
primary_failed = true;
}
if let Err(e) = primary_future.await {
warn!("Error on primary network: {}", e);
self.primary_fail_counter.fetch_add(1, Ordering::Relaxed);
primary_failed = true;
};
if let (BroadcastDelay::View(view), false) = (broadcast_delay, primary_failed) {
let duration = *self.delay_duration.read().await;
let primary_down = Arc::clone(&self.primary_down);
let primary_fail_counter = Arc::clone(&self.primary_fail_counter);
let mut receiver = self
.delayed_tasks_channels
.write()
.await
.entry(view)
.or_insert_with(|| {
let (s, r) = broadcast(1);
(s, r.deactivate())
})
.1
.activate_cloned();
spawn(async move {
sleep(duration).await;
if receiver.try_recv().is_ok() {
debug!(
"Not sending on secondary after delay, task was canceled in view update"
);
match primary_fail_counter.load(Ordering::Relaxed) {
0u64 => {
primary_down.store(false, Ordering::Relaxed);
debug!("primary_fail_counter reached zero, primary_down set to false");
}
c => {
primary_fail_counter.store(c - 1, Ordering::Relaxed);
debug!("primary_fail_counter set to {:?}", c - 1);
}
}
return Ok(());
}
debug!("Sending on secondary after delay, message possibly has not reached recipient on primary");
primary_fail_counter.fetch_add(1, Ordering::Relaxed);
secondary_future.await
});
Ok(())
} else {
if self.primary_down.load(Ordering::Relaxed) {
match self.no_delay_counter.load(Ordering::Relaxed) {
c if c < COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL => {
self.no_delay_counter.store(c + 1, Ordering::Relaxed);
}
_ => {
debug!(
"Sent on secondary without delay more than {} times,\
try delaying to check primary",
COMBINED_NETWORK_PRIMARY_CHECK_INTERVAL
);
self.no_delay_counter.store(0u64, Ordering::Relaxed);
self.primary_down.store(false, Ordering::Relaxed);
self.primary_fail_counter
.store(COMBINED_NETWORK_MIN_PRIMARY_FAILURES, Ordering::Relaxed);
}
}
}
secondary_future.await
}
}
}
#[derive(Clone)]
pub struct UnderlyingCombinedNetworks<TYPES: NodeType>(
pub PushCdnNetwork<TYPES::SignatureKey>,
pub Libp2pNetwork<TYPES>,
);
#[cfg(feature = "hotshot-testing")]
impl<TYPES: NodeType> TestableNetworkingImplementation<TYPES> for CombinedNetworks<TYPES> {
fn generator(
expected_node_count: usize,
num_bootstrap: usize,
network_id: usize,
da_committee_size: usize,
reliability_config: Option<Box<dyn NetworkReliability>>,
secondary_network_delay: Duration,
) -> AsyncGenerator<Arc<Self>> {
let generators = (
<PushCdnNetwork<TYPES::SignatureKey> as TestableNetworkingImplementation<TYPES>>::generator(
expected_node_count,
num_bootstrap,
network_id,
da_committee_size,
None,
Duration::default(),
),
<Libp2pNetwork<TYPES> as TestableNetworkingImplementation<TYPES>>::generator(
expected_node_count,
num_bootstrap,
network_id,
da_committee_size,
reliability_config,
Duration::default(),
)
);
Box::pin(move |node_id| {
let gen0 = generators.0(node_id);
let gen1 = generators.1(node_id);
Box::pin(async move {
let cdn = gen0.await;
let cdn = Arc::<PushCdnNetwork<TYPES::SignatureKey>>::into_inner(cdn).unwrap();
let p2p = gen1.await;
let underlying_combined = UnderlyingCombinedNetworks(
cdn.clone(),
Arc::<Libp2pNetwork<TYPES>>::unwrap_or_clone(p2p),
);
let message_cache = Arc::new(PlRwLock::new(LruCache::new(
NonZeroUsize::new(COMBINED_NETWORK_CACHE_SIZE).unwrap(),
)));
let combined_network = Self {
networks: Arc::new(underlying_combined),
primary_fail_counter: Arc::new(AtomicU64::new(0)),
primary_down: Arc::new(AtomicBool::new(false)),
message_cache: Arc::clone(&message_cache),
delay_duration: Arc::new(RwLock::new(secondary_network_delay)),
delayed_tasks_channels: Arc::default(),
no_delay_counter: Arc::new(AtomicU64::new(0)),
};
Arc::new(combined_network)
})
})
}
fn in_flight_message_count(&self) -> Option<usize> {
None
}
}
#[async_trait]
impl<TYPES: NodeType> ConnectedNetwork<TYPES::SignatureKey> for CombinedNetworks<TYPES> {
fn pause(&self) {
self.networks.0.pause();
}
fn resume(&self) {
self.networks.0.resume();
}
async fn wait_for_ready(&self) {
join!(
self.primary().wait_for_ready(),
self.secondary().wait_for_ready()
);
}
fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
where
'a: 'b,
Self: 'b,
{
let closure = async move {
join!(self.primary().shut_down(), self.secondary().shut_down());
};
boxed_sync(closure)
}
async fn broadcast_message(
&self,
message: Vec<u8>,
topic: Topic,
broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
let primary = self.primary().clone();
let secondary = self.secondary().clone();
let primary_message = message.clone();
let secondary_message = message.clone();
let topic_clone = topic.clone();
self.send_both_networks(
message,
async move {
primary
.broadcast_message(primary_message, topic_clone, BroadcastDelay::None)
.await
},
async move {
secondary
.broadcast_message(secondary_message, topic, BroadcastDelay::None)
.await
},
broadcast_delay,
)
.await
}
async fn da_broadcast_message(
&self,
message: Vec<u8>,
recipients: Vec<TYPES::SignatureKey>,
broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
let primary = self.primary().clone();
let secondary = self.secondary().clone();
let primary_message = message.clone();
let secondary_message = message.clone();
let primary_recipients = recipients.clone();
self.send_both_networks(
message,
async move {
primary
.da_broadcast_message(primary_message, primary_recipients, BroadcastDelay::None)
.await
},
async move {
secondary
.da_broadcast_message(secondary_message, recipients, BroadcastDelay::None)
.await
},
broadcast_delay,
)
.await
}
async fn direct_message(
&self,
message: Vec<u8>,
recipient: TYPES::SignatureKey,
) -> Result<(), NetworkError> {
let primary = self.primary().clone();
let secondary = self.secondary().clone();
let primary_message = message.clone();
let secondary_message = message.clone();
let primary_recipient = recipient.clone();
self.send_both_networks(
message,
async move {
primary
.direct_message(primary_message, primary_recipient)
.await
},
async move { secondary.direct_message(secondary_message, recipient).await },
BroadcastDelay::None,
)
.await
}
async fn vid_broadcast_message(
&self,
messages: HashMap<TYPES::SignatureKey, Vec<u8>>,
) -> Result<(), NetworkError> {
self.networks.0.vid_broadcast_message(messages).await
}
async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
loop {
let mut primary_fut = self.primary().recv_message().fuse();
let mut secondary_fut = self.secondary().recv_message().fuse();
let message = select! {
p = primary_fut => p?,
s = secondary_fut => s?,
};
let message_hash = calculate_hash_of(&message);
if self.message_cache.write().put(message_hash, ()).is_none() {
break Ok(message);
}
}
}
fn queue_node_lookup(
&self,
view_number: ViewNumber,
pk: TYPES::SignatureKey,
) -> Result<(), TrySendError<Option<(ViewNumber, TYPES::SignatureKey)>>> {
self.primary().queue_node_lookup(view_number, pk.clone())?;
self.secondary().queue_node_lookup(view_number, pk)
}
async fn update_view<'a, T>(&'a self, view: u64, epoch: u64, membership: &T::Membership)
where
T: NodeType<SignatureKey = TYPES::SignatureKey> + 'a,
{
let delayed_tasks_channels = Arc::clone(&self.delayed_tasks_channels);
spawn(async move {
let mut map_lock = delayed_tasks_channels.write().await;
while let Some((first_view, _)) = map_lock.first_key_value() {
if *first_view < view {
if let Some((_, (sender, _))) = map_lock.pop_first() {
let _ = sender.try_broadcast(());
} else {
break;
}
} else {
break;
}
}
});
self.networks
.1
.update_view::<T>(view, epoch, membership)
.await;
}
fn is_primary_down(&self) -> bool {
self.primary_down.load(Ordering::Relaxed)
}
}