use core::time::Duration;
use std::{
collections::BTreeSet,
fmt::Debug,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use async_compatibility_layer::{
art::async_spawn,
channel::{bounded, BoundedStream, Receiver, SendError, Sender},
};
use async_lock::{Mutex, RwLock};
use async_trait::async_trait;
use dashmap::DashMap;
use futures::StreamExt;
use hotshot_types::{
boxed_sync,
traits::{
network::{
AsyncGenerator, BroadcastDelay, ConnectedNetwork, TestableNetworkingImplementation,
Topic,
},
node_implementation::NodeType,
signature_key::SignatureKey,
},
BoxSyncFuture,
};
use rand::Rng;
use tracing::{debug, error, info, info_span, instrument, trace, warn, Instrument};
use super::{NetworkError, NetworkReliability};
/// Shared registry that every [`MemoryNetwork`] created against it registers into.
#[derive(custom_debug::Debug)]
pub struct MasterMap<K: SignatureKey> {
    /// Every registered node, keyed by its public key (omitted from `Debug` output).
    #[debug(skip)]
    map: DashMap<K, MemoryNetwork<K>>,
    /// For each topic, the `(public key, node)` pairs subscribed to it, in registration order.
    subscribed_map: DashMap<Topic, Vec<(K, MemoryNetwork<K>)>>,
    /// Random identifier chosen when the map is constructed.
    id: u64,
}
impl<K: SignatureKey> MasterMap<K> {
    /// Builds an empty registry with a freshly drawn random `id`, already wrapped in an
    /// [`Arc`] so it can be shared between the nodes that register into it.
    #[must_use]
    pub fn new() -> Arc<MasterMap<K>> {
        let id: u64 = rand::thread_rng().gen();
        let registry = Self {
            map: DashMap::new(),
            subscribed_map: DashMap::new(),
            id,
        };
        Arc::new(registry)
    }
}
/// Shared state behind a [`MemoryNetwork`] handle.
#[derive(Debug)]
struct MemoryNetworkInner<K: SignatureKey> {
    /// Sending half of this node's input channel. `shut_down` sets it to `None`, after which
    /// `input` rejects every message.
    input: RwLock<Option<Sender<Vec<u8>>>>,
    /// Receiving half of this node's output channel, filled by the background forwarding task
    /// and drained by `recv_message`.
    output: Mutex<Receiver<Vec<u8>>>,
    /// Registry this node belongs to; used to look up peers when sending.
    master_map: Arc<MasterMap<K>>,
    /// Counter incremented by `input` and decremented by `recv_message`.
    in_flight_message_count: AtomicUsize,
    /// When set, outgoing sends are routed through its `chaos_send_msg` instead of being
    /// delivered directly.
    reliability_config: Option<Box<dyn NetworkReliability>>,
}
/// Handle to one node of the in-memory network. Cloning only bumps a reference count, and
/// all clones share the same underlying state.
#[derive(Clone)]
pub struct MemoryNetwork<K: SignatureKey> {
    /// Reference-counted shared state of this node.
    inner: Arc<MemoryNetworkInner<K>>,
}
impl<K: SignatureKey> Debug for MemoryNetwork<K> {
    /// Prints an opaque placeholder for the inner state. The real state contains the shared
    /// registry, which references every other node, so printing it would be huge and
    /// recursive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MemoryNetwork");
        builder.field("inner", &"inner");
        builder.finish()
    }
}
impl<K: SignatureKey> MemoryNetwork<K> {
    /// Creates a node for `pub_key` and registers it in `master_map`.
    ///
    /// The node owns two bounded channels (capacity 128). A spawned background task forwards
    /// every message pushed onto the input channel into the output channel, which
    /// `recv_message` reads from. The node is inserted into `master_map.map` under `pub_key`.
    /// It is also appended to `master_map.subscribed_map` for each of `subscribed_topics`, so
    /// broadcasts on those topics reach it.
    ///
    /// `reliability_config`, when `Some`, makes this node's outgoing sends go through the
    /// config's `chaos_send_msg`.
    pub fn new(
        pub_key: &K,
        master_map: &Arc<MasterMap<K>>,
        subscribed_topics: &[Topic],
        reliability_config: Option<Box<dyn NetworkReliability>>,
    ) -> MemoryNetwork<K> {
        info!("Attaching new MemoryNetwork");
        let (input, task_recv) = bounded(128);
        let (task_send, output) = bounded(128);
        let in_flight_message_count = AtomicUsize::new(0);
        trace!("Channels open, spawning background task");
        async_spawn(
            async move {
                debug!("Starting background task");
                let mut task_stream: BoundedStream<Vec<u8>> = task_recv.into_stream();
                trace!("Entering processing loop");
                while let Some(vec) = task_stream.next().await {
                    trace!(?vec, "Incoming message");
                    // `send` only borrows the sender, so no per-message clone is needed.
                    if task_send.send(vec).await.is_ok() {
                        trace!("Passed message to output queue");
                    } else {
                        error!("Output queue receivers are shutdown");
                    }
                }
            }
            .instrument(info_span!("MemoryNetwork Background task", map = ?master_map)),
        );
        trace!("Notifying other networks of the new connected peer");
        trace!("Task spawned, creating MemoryNetwork");
        let mn = MemoryNetwork {
            inner: Arc::new(MemoryNetworkInner {
                input: RwLock::new(Some(input)),
                output: Mutex::new(output),
                master_map: Arc::clone(master_map),
                in_flight_message_count,
                reliability_config,
            }),
        };
        master_map.map.insert(pub_key.clone(), mn.clone());
        for topic in subscribed_topics {
            master_map
                .subscribed_map
                .entry(topic.clone())
                .or_default()
                .push((pub_key.clone(), mn.clone()));
        }
        mn
    }

    /// Pushes `message` onto this node's input channel.
    ///
    /// The in-flight counter is incremented *before* sending. That way the matching decrement
    /// in `recv_message` can never run first and underflow the counter. If the message
    /// could not be queued, the increment is rolled back so the counter does not leak.
    ///
    /// # Errors
    /// Returns the message inside a `SendError` when the node has been shut down (input
    /// sender cleared) or when the underlying channel send fails.
    async fn input(&self, message: Vec<u8>) -> Result<(), SendError<Vec<u8>>> {
        let counter = &self.inner.in_flight_message_count;
        counter.fetch_add(1, Ordering::Relaxed);
        let input = self.inner.input.read().await;
        let result = match &*input {
            Some(sender) => sender.send(message).await,
            None => Err(SendError(message)),
        };
        if result.is_err() {
            // Nothing was queued, so nothing will ever be received to balance the increment.
            counter.fetch_sub(1, Ordering::Relaxed);
        }
        result
    }
}
impl<TYPES: NodeType> TestableNetworkingImplementation<TYPES>
for MemoryNetwork<TYPES::SignatureKey>
{
fn generator(
_expected_node_count: usize,
_num_bootstrap: usize,
_network_id: usize,
da_committee_size: usize,
_is_da: bool,
reliability_config: Option<Box<dyn NetworkReliability>>,
_secondary_network_delay: Duration,
) -> AsyncGenerator<Arc<Self>> {
let master: Arc<_> = MasterMap::new();
Box::pin(move |node_id| {
let privkey = TYPES::SignatureKey::generated_from_seed_indexed([0u8; 32], node_id).1;
let pubkey = TYPES::SignatureKey::from_private(&privkey);
let subscribed_topics = if node_id < da_committee_size as u64 {
vec![Topic::Da, Topic::Global]
} else {
vec![Topic::Global]
};
let net = MemoryNetwork::new(
&pubkey,
&master,
&subscribed_topics,
reliability_config.clone(),
);
Box::pin(async move { net.into() })
})
}
fn in_flight_message_count(&self) -> Option<usize> {
Some(self.inner.in_flight_message_count.load(Ordering::Relaxed))
}
}
#[async_trait]
impl<K: SignatureKey + 'static> ConnectedNetwork<K> for MemoryNetwork<K> {
#[instrument(name = "MemoryNetwork::ready_blocking")]
async fn wait_for_ready(&self) {}
fn pause(&self) {
unimplemented!("Pausing not implemented for the Memory network");
}
fn resume(&self) {
unimplemented!("Resuming not implemented for the Memory network");
}
#[instrument(name = "MemoryNetwork::shut_down")]
fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
where
'a: 'b,
Self: 'b,
{
let closure = async move {
*self.inner.input.write().await = None;
};
boxed_sync(closure)
}
#[instrument(name = "MemoryNetwork::broadcast_message")]
async fn broadcast_message(
&self,
message: Vec<u8>,
topic: Topic,
_broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
trace!(?message, "Broadcasting message");
for node in self
.inner
.master_map
.subscribed_map
.entry(topic)
.or_default()
.iter()
{
let (key, node) = node;
trace!(?key, "Sending message to node");
if let Some(ref config) = &self.inner.reliability_config {
{
let node2 = node.clone();
let fut = config.chaos_send_msg(
message.clone(),
Arc::new(move |msg: Vec<u8>| {
let node3 = (node2).clone();
boxed_sync(async move {
let _res = node3.input(msg).await;
})
}),
);
async_spawn(fut);
}
} else {
let res = node.input(message.clone()).await;
match res {
Ok(()) => {
trace!(?key, "Delivered message to remote");
}
Err(e) => {
warn!(?e, ?key, "Error sending broadcast message to node");
}
}
}
}
Ok(())
}
#[instrument(name = "MemoryNetwork::da_broadcast_message")]
async fn da_broadcast_message(
&self,
message: Vec<u8>,
recipients: BTreeSet<K>,
broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
let topic = self
.inner
.master_map
.subscribed_map
.iter()
.find(|v| v.value().iter().all(|(k, _)| recipients.contains(k)))
.map(|v| v.key().clone())
.ok_or(NetworkError::NotFound)?;
self.broadcast_message(message, topic, broadcast_delay)
.await
}
#[instrument(name = "MemoryNetwork::direct_message")]
async fn direct_message(&self, message: Vec<u8>, recipient: K) -> Result<(), NetworkError> {
trace!("Message bincoded, finding recipient");
if let Some(node) = self.inner.master_map.map.get(&recipient) {
let node = node.value().clone();
if let Some(ref config) = &self.inner.reliability_config {
{
let fut = config.chaos_send_msg(
message.clone(),
Arc::new(move |msg: Vec<u8>| {
let node2 = node.clone();
boxed_sync(async move {
let _res = node2.input(msg).await;
})
}),
);
async_spawn(fut);
}
Ok(())
} else {
let res = node.input(message).await;
match res {
Ok(()) => {
trace!(?recipient, "Delivered message to remote");
Ok(())
}
Err(e) => {
warn!(?e, ?recipient, "Error delivering direct message");
Err(NetworkError::CouldNotDeliver)
}
}
}
} else {
warn!(
"{:#?} {:#?} {:#?}",
recipient, self.inner.master_map.map, "Node does not exist in map"
);
Err(NetworkError::NoSuchNode)
}
}
#[instrument(name = "MemoryNetwork::recv_messages", skip_all)]
async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
let ret = self
.inner
.output
.lock()
.await
.recv()
.await
.map_err(|_x| NetworkError::ShutDown)?;
self.inner
.in_flight_message_count
.fetch_sub(1, Ordering::Relaxed);
Ok(ret)
}
}