#[cfg(feature = "hotshot-testing")]
use std::sync::atomic::{AtomicBool, Ordering};
use std::{collections::VecDeque, marker::PhantomData, sync::Arc};
#[cfg(feature = "hotshot-testing")]
use std::{path::Path, time::Duration};
use async_trait::async_trait;
use bincode::config::Options;
use cdn_broker::reexports::{
connection::protocols::{Tcp, TcpTls},
def::{hook::NoMessageHook, ConnectionDef, RunDef, Topic as TopicTrait},
discovery::{Embedded, Redis},
};
#[cfg(feature = "hotshot-testing")]
use cdn_broker::{Broker, Config as BrokerConfig};
pub use cdn_client::reexports::crypto::signature::KeyPair;
use cdn_client::{
reexports::{
crypto::signature::{Serializable, SignatureScheme},
message::{Broadcast, Direct, Message as PushCdnMessage},
},
Client, Config as ClientConfig,
};
#[cfg(feature = "hotshot-testing")]
use cdn_marshal::{Config as MarshalConfig, Marshal};
#[cfg(feature = "hotshot-testing")]
use hotshot_types::traits::network::{
AsyncGenerator, NetworkReliability, TestableNetworkingImplementation,
};
use hotshot_types::{
boxed_sync,
data::ViewNumber,
traits::{
metrics::{Counter, Metrics, NoMetrics},
network::{BroadcastDelay, ConnectedNetwork, Topic as HotShotTopic},
node_implementation::NodeType,
signature_key::SignatureKey,
},
utils::bincode_opts,
BoxSyncFuture,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use parking_lot::Mutex;
#[cfg(feature = "hotshot-testing")]
use rand::{rngs::StdRng, RngCore, SeedableRng};
use tokio::{spawn, sync::mpsc::error::TrySendError, time::sleep};
use tracing::error;
use super::NetworkError;
/// Metrics tracked for the CDN network implementation.
#[derive(Clone)]
pub struct CdnMetricsValue {
/// Counter that is incremented each time sending a message over the CDN fails.
pub num_failed_messages: Box<dyn Counter>,
}
impl CdnMetricsValue {
    /// Builds the CDN metrics inside a `"cdn"` subgroup of the supplied metrics sink.
    ///
    /// The only metric registered is the `num_failed_messages` counter.
    pub fn new(metrics: &dyn Metrics) -> Self {
        let cdn_group = metrics.subgroup("cdn".into());
        let num_failed_messages = cdn_group.create_counter("num_failed_messages".into(), None);
        Self {
            num_failed_messages,
        }
    }
}
impl Default for CdnMetricsValue {
    /// Creates the metrics on top of a `NoMetrics` sink.
    fn default() -> Self {
        let no_metrics = NoMetrics::boxed();
        Self::new(&*no_metrics)
    }
}
/// Newtype around a HotShot `SignatureKey`, so that the CDN's `SignatureScheme` and
/// `Serializable` traits can be implemented for it.
#[derive(Clone, Eq, PartialEq)]
pub struct WrappedSignatureKey<T: SignatureKey + 'static>(pub T);
/// Lets the CDN sign and verify with the wrapped HotShot key.
///
/// Both operations work on the namespace bytes immediately followed by the message bytes,
/// and signatures travel as `bincode_opts()`-encoded bytes.
impl<T: SignatureKey> SignatureScheme for WrappedSignatureKey<T> {
    type PrivateKey = T::PrivateKey;
    type PublicKey = Self;

    /// Signs `namespace || message` with `private_key` and returns the encoded signature.
    ///
    /// # Errors
    /// Fails if the underlying key cannot sign or the signature cannot be serialized.
    fn sign(
        private_key: &Self::PrivateKey,
        namespace: &str,
        message: &[u8],
    ) -> anyhow::Result<Vec<u8>> {
        // Prefix the payload with the namespace before signing.
        let mut namespaced = Vec::with_capacity(namespace.len() + message.len());
        namespaced.extend_from_slice(namespace.as_bytes());
        namespaced.extend_from_slice(message);

        let signature = T::sign(private_key, &namespaced)?;
        let encoded = bincode_opts().serialize(&signature)?;
        Ok(encoded)
    }

    /// Checks `signature` over `namespace || message` against `public_key`.
    ///
    /// Returns `false` when the signature bytes cannot be decoded or do not validate.
    fn verify(
        public_key: &Self::PublicKey,
        namespace: &str,
        message: &[u8],
        signature: &[u8],
    ) -> bool {
        let decoded: Result<T::PureAssembledSignatureType, _> =
            bincode_opts().deserialize(signature);
        let Ok(signature) = decoded else {
            return false;
        };

        // Rebuild exactly the byte string that `sign` signed.
        let mut namespaced = Vec::with_capacity(namespace.len() + message.len());
        namespaced.extend_from_slice(namespace.as_bytes());
        namespaced.extend_from_slice(message);

        public_key.0.validate(&signature, &namespaced)
    }
}
/// Byte (de)serialization of the wrapped key, delegating to the key's own
/// `to_bytes` / `from_bytes`.
impl<T: SignatureKey> Serializable for WrappedSignatureKey<T> {
    /// Returns the wrapped key's byte representation.
    fn serialize(&self) -> anyhow::Result<Vec<u8>> {
        let bytes = self.0.to_bytes();
        Ok(bytes)
    }

    /// Rebuilds a wrapped key from bytes.
    ///
    /// # Errors
    /// Fails if the bytes are rejected by the key's `from_bytes`.
    fn deserialize(serialized: &[u8]) -> anyhow::Result<Self> {
        let inner = T::from_bytes(serialized)?;
        Ok(WrappedSignatureKey(inner))
    }
}
/// The production run definition for the CDN: `UserDef` and `BrokerDef` connections,
/// `Redis` as the discovery client, and this module's `Topic` type.
pub struct ProductionDef<K: SignatureKey + 'static>(PhantomData<K>);
impl<K: SignatureKey + 'static> RunDef for ProductionDef<K> {
type User = UserDef<K>;
type Broker = BrokerDef<K>;
type DiscoveryClientType = Redis;
type Topic = Topic;
}
/// Connection definition used as the `User` side of the run definitions:
/// `WrappedSignatureKey<K>` as the signature scheme, `TcpTls` as the protocol,
/// and no message hook.
pub struct UserDef<K: SignatureKey + 'static>(PhantomData<K>);
impl<K: SignatureKey + 'static> ConnectionDef for UserDef<K> {
type Scheme = WrappedSignatureKey<K>;
type Protocol = TcpTls;
type MessageHook = NoMessageHook;
}
/// Connection definition used as the `Broker` side of the run definitions:
/// `WrappedSignatureKey<K>` as the signature scheme, plain `Tcp` as the protocol,
/// and no message hook.
pub struct BrokerDef<K: SignatureKey + 'static>(PhantomData<K>);
impl<K: SignatureKey> ConnectionDef for BrokerDef<K> {
type Scheme = WrappedSignatureKey<K>;
type Protocol = Tcp;
type MessageHook = NoMessageHook;
}
/// Connection definition for the `Client` held by `PushCdnNetwork`:
/// `WrappedSignatureKey<K>` as the signature scheme, `TcpTls` as the protocol,
/// and no message hook.
#[derive(Clone)]
pub struct ClientDef<K: SignatureKey + 'static>(PhantomData<K>);
impl<K: SignatureKey> ConnectionDef for ClientDef<K> {
type Scheme = WrappedSignatureKey<K>;
type Protocol = TcpTls;
type MessageHook = NoMessageHook;
}
/// The testing run definition for the CDN. Identical to `ProductionDef` except that it
/// uses the `Embedded` discovery client instead of `Redis`. The test network generator
/// runs its brokers and marshal with this definition.
pub struct TestingDef<K: SignatureKey + 'static>(PhantomData<K>);
impl<K: SignatureKey + 'static> RunDef for TestingDef<K> {
type User = UserDef<K>;
type Broker = BrokerDef<K>;
type DiscoveryClientType = Embedded;
type Topic = Topic;
}
/// A `ConnectedNetwork` implementation backed by a Push CDN client.
/// Cloning shares the metrics, the internal queue and (in testing) the pause flag,
/// since those are held behind `Arc`s.
#[derive(Clone)]
pub struct PushCdnNetwork<K: SignatureKey + 'static> {
/// The CDN client used to send and receive messages.
client: Client<ClientDef<K>>,
/// Metrics for this network (failed-message counter).
metrics: Arc<CdnMetricsValue>,
/// Direct messages addressed to our own public key. `recv_message` drains this
/// queue before asking the CDN client for a message.
internal_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
/// Our own public key, used to detect direct messages sent to ourselves.
public_key: K,
/// Testing only: while `true`, outgoing messages are silently dropped and
/// messages received from the CDN are discarded.
#[cfg(feature = "hotshot-testing")]
is_paused: Arc<AtomicBool>,
}
/// The topics a CDN client can subscribe to and broadcast on. Each variant has a fixed
/// `u8` value, which is the form handed to the CDN client.
#[repr(u8)]
#[derive(IntoPrimitive, TryFromPrimitive, Clone, PartialEq, Eq)]
pub enum Topic {
/// The global topic (the counterpart of `HotShotTopic::Global`).
Global = 0,
/// The DA topic (the counterpart of `HotShotTopic::Da`); used by `da_broadcast_message`.
Da = 1,
}
// Marker impl so `Topic` can serve as the `Topic` type of the run definitions.
impl TopicTrait for Topic {}
impl<K: SignatureKey + 'static> PushCdnNetwork<K> {
    /// Creates a new `PushCdnNetwork`.
    ///
    /// Builds a CDN client configured with the marshal endpoint, the topics to subscribe
    /// to (converted to their `u8` values) and the given keypair. The keypair's public key
    /// is kept so that direct messages addressed to ourselves can be short-circuited.
    ///
    /// # Errors
    /// The body has no failing path today and always returns `Ok`; the `Result` is part
    /// of the public signature and is kept for compatibility.
    pub fn new(
        marshal_endpoint: String,
        topics: Vec<Topic>,
        keypair: KeyPair<WrappedSignatureKey<K>>,
        metrics: CdnMetricsValue,
    ) -> anyhow::Result<Self> {
        let config = ClientConfig {
            endpoint: marshal_endpoint,
            subscribed_topics: topics.into_iter().map(|t| t as u8).collect(),
            keypair: keypair.clone(),
            use_local_authority: true,
        };
        let client = Client::new(config);

        Ok(Self {
            client,
            metrics: Arc::from(metrics),
            internal_queue: Arc::new(Mutex::new(VecDeque::new())),
            public_key: keypair.public_key.0,
            #[cfg(feature = "hotshot-testing")]
            is_paused: Arc::from(AtomicBool::new(false)),
        })
    }

    /// Broadcasts `message` to every subscriber of `topic`.
    ///
    /// In testing builds, a paused network silently drops the message and returns `Ok`.
    ///
    /// # Errors
    /// Returns `NetworkError::MessageSendError` if the CDN client fails to send.
    async fn broadcast_message(&self, message: Vec<u8>, topic: Topic) -> Result<(), NetworkError> {
        // While paused, pretend the send succeeded without touching the network.
        #[cfg(feature = "hotshot-testing")]
        if self.is_paused.load(Ordering::Relaxed) {
            return Ok(());
        }

        if let Err(err) = self
            .client
            .send_broadcast_message(vec![topic as u8], message)
            .await
        {
            // This is a send failure, so report it with the same variant that
            // `direct_message` uses rather than as a receive error.
            return Err(NetworkError::MessageSendError(format!(
                "failed to send broadcast message: {err}"
            )));
        }

        Ok(())
    }
}
/// Test-only support: starts a local CDN (two brokers and one marshal) and hands out
/// one `PushCdnNetwork` client per node.
#[cfg(feature = "hotshot-testing")]
impl<TYPES: NodeType> TestableNetworkingImplementation<TYPES>
for PushCdnNetwork<TYPES::SignatureKey>
{
/// Spawns two brokers and a marshal on locally picked `127.0.0.1` ports, all pointing
/// at the same discovery endpoint, and returns a generator that builds a
/// `PushCdnNetwork` for a given node id.
///
/// Of the parameters, only `da_committee_size` is used: nodes whose id is below it
/// subscribe to both the DA and Global topics, all other nodes only to Global.
///
/// # Panics
/// Panics if no unused port can be found. The spawned broker and marshal tasks panic
/// if their construction fails.
#[allow(clippy::too_many_lines)]
fn generator(
_expected_node_count: usize,
_num_bootstrap: usize,
_network_id: usize,
da_committee_size: usize,
_reliability_config: Option<Box<dyn NetworkReliability>>,
_secondary_network_delay: Duration,
) -> AsyncGenerator<Arc<Self>> {
// Both brokers share one deterministic keypair (all-zero seed, index 1337).
let (broker_public_key, broker_private_key) =
TYPES::SignatureKey::generated_from_seed_indexed([0u8; 32], 1337);
// The discovery endpoint is a randomly named `.sqlite` path in the system temp dir,
// shared by the brokers and the marshal.
let temp_dir = std::env::temp_dir();
let discovery_endpoint = temp_dir
.join(Path::new(&format!(
"test-{}.sqlite",
StdRng::from_entropy().next_u64()
)))
.to_string_lossy()
.into_owned();
// Pick both public addresses up front so each broker can compute the other's identifier.
let public_address_1 = format!(
"127.0.0.1:{}",
portpicker::pick_unused_port().expect("could not find an open port")
);
let public_address_2 = format!(
"127.0.0.1:{}",
portpicker::pick_unused_port().expect("could not find an open port")
);
// Start the two brokers.
for i in 0..2 {
let private_port = portpicker::pick_unused_port().expect("could not find an open port");
let private_address = format!("127.0.0.1:{private_port}");
let (public_address, other_public_address) = if i == 0 {
(public_address_1.clone(), public_address_2.clone())
} else {
(public_address_2.clone(), public_address_1.clone())
};
// NOTE(review): both halves of each identifier use the public address; confirm this
// matches the identifier format the broker itself uses.
let broker_identifier = format!("{public_address}/{public_address}");
let other_broker_identifier = format!("{other_public_address}/{other_public_address}");
let config: BrokerConfig<TestingDef<TYPES::SignatureKey>> = BrokerConfig {
public_advertise_endpoint: public_address.clone(),
public_bind_endpoint: public_address,
private_advertise_endpoint: private_address.clone(),
private_bind_endpoint: private_address,
metrics_bind_endpoint: None,
keypair: KeyPair {
public_key: WrappedSignatureKey(broker_public_key.clone()),
private_key: broker_private_key.clone(),
},
discovery_endpoint: discovery_endpoint.clone(),
user_message_hook: NoMessageHook,
broker_message_hook: NoMessageHook,
ca_cert_path: None,
ca_key_path: None,
// 1 GiB
global_memory_pool_size: Some(1024 * 1024 * 1024),
};
// Run the broker in a background task; the task is detached.
spawn(async move {
let broker: Broker<TestingDef<TYPES::SignatureKey>> =
Broker::new(config).await.expect("broker failed to start");
// The broker whose identifier compares lower waits two seconds before starting.
// NOTE(review): presumably this staggers start-up so the two brokers do not
// connect to each other at the same time; confirm.
if other_broker_identifier > broker_identifier {
sleep(Duration::from_secs(2)).await;
}
if let Err(err) = broker.start().await {
error!("broker stopped: {err}");
}
});
}
// Start the marshal on its own port, using the same discovery endpoint as the brokers.
let marshal_port = portpicker::pick_unused_port().expect("could not find an open port");
let marshal_endpoint = format!("127.0.0.1:{marshal_port}");
let marshal_config = MarshalConfig {
bind_endpoint: marshal_endpoint.clone(),
discovery_endpoint,
metrics_bind_endpoint: None,
ca_cert_path: None,
ca_key_path: None,
// 1 GiB
global_memory_pool_size: Some(1024 * 1024 * 1024),
};
// Run the marshal in a background task; the task is detached.
spawn(async move {
let marshal: Marshal<TestingDef<TYPES::SignatureKey>> = Marshal::new(marshal_config)
.await
.expect("failed to spawn marshal");
if let Err(err) = marshal.start().await {
error!("marshal stopped: {err}");
}
});
// The returned generator builds one client per node id, all pointed at the marshal.
Box::pin({
move |node_id| {
let marshal_endpoint = marshal_endpoint.clone();
Box::pin(async move {
// Deterministic per-node keypair: all-zero seed, indexed by the node id.
let private_key =
TYPES::SignatureKey::generated_from_seed_indexed([0u8; 32], node_id).1;
let public_key = TYPES::SignatureKey::from_private(&private_key);
// Nodes with an id below the DA committee size also subscribe to the DA topic.
let topics = if node_id < da_committee_size as u64 {
vec![Topic::Da as u8, Topic::Global as u8]
} else {
vec![Topic::Global as u8]
};
let client_config: ClientConfig<ClientDef<TYPES::SignatureKey>> =
ClientConfig {
keypair: KeyPair {
public_key: WrappedSignatureKey(public_key.clone()),
private_key,
},
subscribed_topics: topics,
endpoint: marshal_endpoint,
use_local_authority: true,
};
Arc::new(PushCdnNetwork {
client: Client::new(client_config),
metrics: Arc::new(CdnMetricsValue::default()),
internal_queue: Arc::new(Mutex::new(VecDeque::new())),
public_key,
// The enclosing impl is already gated on the same feature.
#[cfg(feature = "hotshot-testing")]
is_paused: Arc::from(AtomicBool::new(false)),
})
})
}
})
}
/// Always returns `None`.
fn in_flight_message_count(&self) -> Option<usize> {
None
}
}
#[async_trait]
impl<K: SignatureKey + 'static> ConnectedNetwork<K> for PushCdnNetwork<K> {
fn pause(&self) {
#[cfg(feature = "hotshot-testing")]
self.is_paused.store(true, Ordering::Relaxed);
}
fn resume(&self) {
#[cfg(feature = "hotshot-testing")]
self.is_paused.store(false, Ordering::Relaxed);
}
async fn wait_for_ready(&self) {
let _ = self.client.ensure_initialized().await;
}
fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()>
where
'a: 'b,
Self: 'b,
{
boxed_sync(async move { self.client.close().await })
}
async fn broadcast_message(
&self,
message: Vec<u8>,
topic: HotShotTopic,
_broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
#[cfg(feature = "hotshot-testing")]
if self.is_paused.load(Ordering::Relaxed) {
return Ok(());
}
self.broadcast_message(message, topic.into())
.await
.inspect_err(|_e| {
self.metrics.num_failed_messages.add(1);
})
}
async fn da_broadcast_message(
&self,
message: Vec<u8>,
_recipients: Vec<K>,
_broadcast_delay: BroadcastDelay,
) -> Result<(), NetworkError> {
#[cfg(feature = "hotshot-testing")]
if self.is_paused.load(Ordering::Relaxed) {
return Ok(());
}
self.broadcast_message(message, Topic::Da)
.await
.inspect_err(|_e| {
self.metrics.num_failed_messages.add(1);
})
}
async fn direct_message(&self, message: Vec<u8>, recipient: K) -> Result<(), NetworkError> {
#[cfg(feature = "hotshot-testing")]
if self.is_paused.load(Ordering::Relaxed) {
return Ok(());
}
if recipient == self.public_key {
self.internal_queue.lock().push_back(message);
return Ok(());
}
if let Err(e) = self
.client
.send_direct_message(&WrappedSignatureKey(recipient), message)
.await
{
self.metrics.num_failed_messages.add(1);
return Err(NetworkError::MessageSendError(format!(
"failed to send direct message: {e}"
)));
};
Ok(())
}
async fn recv_message(&self) -> Result<Vec<u8>, NetworkError> {
if let Some(message) = self.internal_queue.lock().pop_front() {
return Ok(message);
}
let message = self.client.receive_message().await;
#[cfg(feature = "hotshot-testing")]
if self.is_paused.load(Ordering::Relaxed) {
sleep(Duration::from_millis(100)).await;
return Ok(vec![]);
}
let message = match message {
Ok(message) => message,
Err(error) => {
return Err(NetworkError::MessageReceiveError(format!(
"failed to receive message: {error}"
)));
}
};
let (PushCdnMessage::Broadcast(Broadcast { message, topics: _ })
| PushCdnMessage::Direct(Direct {
message,
recipient: _,
})) = message
else {
return Ok(vec![]);
};
Ok(message)
}
fn queue_node_lookup(
&self,
_view_number: ViewNumber,
_pk: K,
) -> Result<(), TrySendError<Option<(ViewNumber, K)>>> {
Ok(())
}
}
/// Maps each HotShot topic onto the CDN topic of the same name.
impl From<HotShotTopic> for Topic {
    fn from(topic: HotShotTopic) -> Self {
        match topic {
            HotShotTopic::Global => Self::Global,
            HotShotTopic::Da => Self::Da,
        }
    }
}