use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use anyhow::{ensure, Result};
use async_broadcast::{Receiver, Sender};
use async_compatibility_layer::art::{async_sleep, async_spawn, async_timeout};
#[cfg(async_executor_impl = "async-std")]
use async_std::task::JoinHandle;
use async_trait::async_trait;
use committable::Committable;
use hotshot_task::task::TaskState;
use hotshot_types::{
consensus::OuterConsensus,
message::{DaConsensusMessage, DataMessage, Message, MessageKind, SequencingMessage},
traits::{
election::Membership,
network::{ConnectedNetwork, DataRequest, RequestKind, ResponseMessage},
node_implementation::{NodeImplementation, NodeType},
signature_key::SignatureKey,
},
vote::HasViewNumber,
};
use rand::{prelude::SliceRandom, thread_rng};
use sha2::{Digest, Sha256};
#[cfg(async_executor_impl = "tokio")]
use tokio::task::JoinHandle;
use tracing::{debug, error, info, instrument, warn};
use crate::{events::HotShotEvent, helpers::broadcast_event};
/// Upper bound on how long a single data request to one peer may take.
///
/// The same duration is also reused as the pause after a successful
/// (`Found`) response and after a send error before the next attempt.
pub const REQUEST_TIMEOUT: Duration = Duration::from_millis(500);
/// Task state that requests missing data (currently VID shares) from peers
/// after a quorum proposal has been validated, and answers incoming quorum
/// proposal requests from other nodes.
pub struct NetworkRequestState<TYPES: NodeType, I: NodeImplementation<TYPES>> {
/// Network handle; shared with every spawned requester, which uses it to
/// send data requests.
pub network: Arc<I::Network>,
/// Consensus state; read to check which VID shares are already stored and
/// to look up previously seen proposals when answering proposal requests.
pub state: OuterConsensus<TYPES>,
/// Latest view observed through `ViewChange` events. Validated proposals
/// for views older than this do not trigger requests.
pub view: TYPES::Time,
/// Delay handed to each spawned requester; it sleeps this long before
/// sending unless the network reports its primary as down.
pub delay: Duration,
/// Membership whose committee members for the requested view are used as
/// the (shuffled) recipients of data requests.
pub da_membership: TYPES::Membership,
/// Quorum membership. NOTE(review): not read anywhere in this part of the
/// file — confirm whether it is still needed.
pub quorum_membership: TYPES::Membership,
/// This node's public key; embedded in outgoing VID requests and used to
/// look up this node's own VID share.
pub public_key: TYPES::SignatureKey,
/// This node's private key; used to sign outgoing data requests.
pub private_key: <TYPES::SignatureKey as SignatureKey>::PrivateKey,
/// Identifier recorded in the tracing spans of this task (presumably the
/// node id — verify against the constructor).
pub id: u64,
/// Flag shared with all spawned requesters; once set to `true` they stop
/// retrying.
pub shutdown_flag: Arc<AtomicBool>,
/// Join handles of the spawned requester tasks, grouped by the view they
/// were spawned for. Drained and cancelled in `cancel_subtasks`.
pub spawned_tasks: BTreeMap<TYPES::Time, Vec<JoinHandle<()>>>,
}
impl<TYPES: NodeType, I: NodeImplementation<TYPES>> Drop for NetworkRequestState<TYPES, I> {
    /// Cancels every spawned requester task when the state is dropped.
    ///
    /// The cancellation future is driven to completion synchronously because
    /// `drop` cannot be `async`.
    fn drop(&mut self) {
        futures::executor::block_on(self.cancel_subtasks());
    }
}
/// Shorthand for the signature type associated with a node type's signature
/// key; this is the type produced when a data request is signed.
type Signature<TYPES> =
<<TYPES as NodeType>::SignatureKey as SignatureKey>::PureAssembledSignatureType;
#[async_trait]
impl<TYPES: NodeType, I: NodeImplementation<TYPES>> TaskState for NetworkRequestState<TYPES, I> {
    type Event = HotShotEvent<TYPES>;

    /// Reacts to the events this task cares about:
    ///
    /// * `QuorumProposalValidated` — spawns data requests for the proposal's
    ///   view, unless that view is older than the current one.
    /// * `ViewChange` — advances the locally tracked view (never moves it
    ///   backwards).
    /// * `QuorumProposalRequestRecv` — validates the requester's signature
    ///   over the request commitment and, if a proposal for the requested
    ///   view is known, emits a `QuorumProposalResponseSend` event.
    ///
    /// All other events are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error when the signature on a proposal request is invalid.
    #[instrument(skip_all, target = "NetworkRequestState", fields(id = self.id))]
    async fn handle_event(
        &mut self,
        event: Arc<Self::Event>,
        sender: &Sender<Arc<Self::Event>>,
        _receiver: &Receiver<Arc<Self::Event>>,
    ) -> Result<()> {
        match event.as_ref() {
            HotShotEvent::QuorumProposalValidated(proposal, _) => {
                let prop_view = proposal.view_number();
                if prop_view >= self.view {
                    self.spawn_requests(prop_view, sender.clone()).await;
                }
                Ok(())
            }
            HotShotEvent::ViewChange(view) => {
                let view = *view;
                if view > self.view {
                    self.view = view;
                }
                Ok(())
            }
            HotShotEvent::QuorumProposalRequestRecv(req, signature) => {
                ensure!(
                    req.key.validate(signature, req.commit().as_ref()),
                    "Invalid signature key on proposal request."
                );

                // Clone the proposal out in its own statement so the consensus
                // read guard is released *before* the broadcast is awaited.
                // Using the lookup directly as an `if let` scrutinee would keep
                // the guard alive for the whole body, i.e. across the await.
                let quorum_proposal = self
                    .state
                    .read()
                    .await
                    .last_proposals()
                    .get(&req.view_number)
                    .cloned();

                if let Some(quorum_proposal) = quorum_proposal {
                    broadcast_event(
                        HotShotEvent::QuorumProposalResponseSend(req.key.clone(), quorum_proposal)
                            .into(),
                        sender,
                    )
                    .await;
                }
                Ok(())
            }
            _ => Ok(()),
        }
    }

    /// Signals shutdown to all requesters and cancels every spawned task.
    ///
    /// The shutdown flag is set first so that tasks which are still running
    /// stop retrying on their own; afterwards every stored join handle is
    /// removed from the map and cancelled/aborted.
    async fn cancel_subtasks(&mut self) {
        self.set_shutdown_flag();
        while let Some((_, handles)) = self.spawned_tasks.pop_first() {
            for handle in handles {
                #[cfg(async_executor_impl = "async-std")]
                handle.cancel().await;
                #[cfg(async_executor_impl = "tokio")]
                handle.abort();
            }
        }
    }
}
impl<TYPES: NodeType, I: NodeImplementation<TYPES>> NetworkRequestState<TYPES, I> {
    /// Works out which requests are needed for `view` and starts one delayed
    /// requester task for each of them.
    async fn spawn_requests(
        &mut self,
        view: TYPES::Time,
        sender: Sender<Arc<HotShotEvent<TYPES>>>,
    ) {
        for request in self.build_requests(view).await {
            self.run_delay(request, sender.clone(), view);
        }
    }

    /// Returns the requests that should be sent for `view`.
    ///
    /// At the moment this is a single VID request, produced only when the
    /// consensus state holds no VID shares for that view yet.
    #[instrument(skip_all, target = "NetworkRequestState", fields(id = self.id, view = *view))]
    async fn build_requests(&self, view: TYPES::Time) -> Vec<RequestKind<TYPES>> {
        let have_vid_share = self.state.read().await.vid_shares().contains_key(&view);
        if have_vid_share {
            Vec::new()
        } else {
            vec![RequestKind::Vid(view, self.public_key.clone())]
        }
    }

    /// Serializes `request`, hashes the bytes with SHA-256 and signs the
    /// digest with this node's private key.
    ///
    /// Returns `None` (after logging an error) if either serialization or
    /// signing fails.
    fn serialize_and_sign(
        &self,
        request: &RequestKind<TYPES>,
    ) -> Option<<TYPES::SignatureKey as SignatureKey>::PureAssembledSignatureType> {
        let data = match bincode::serialize(&request) {
            Ok(bytes) => bytes,
            Err(_) => {
                tracing::error!("Failed to serialize request!");
                return None;
            }
        };
        match TYPES::SignatureKey::sign(&self.private_key, &Sha256::digest(data)) {
            Ok(signature) => Some(signature),
            Err(_) => {
                error!("Failed to sign Data Request");
                None
            }
        }
    }

    /// Spawns a background task that sends `request` to the DA committee
    /// members of `view`, and records its join handle under that view.
    ///
    /// The recipient order is shuffled so that requesters do not all start
    /// with the same peer. Nothing is spawned if the request cannot be signed.
    #[instrument(skip_all, fields(id = self.id, view = *self.view), name = "NetworkRequestState run_delay", level = "error")]
    fn run_delay(
        &mut self,
        request: RequestKind<TYPES>,
        sender: Sender<Arc<HotShotEvent<TYPES>>>,
        view: TYPES::Time,
    ) {
        let mut recipients: Vec<TYPES::SignatureKey> = self
            .da_membership
            .committee_members(view)
            .into_iter()
            .collect();
        recipients.shuffle(&mut thread_rng());

        let network = Arc::clone(&self.network);
        let state = OuterConsensus::new(Arc::clone(&self.state.inner_consensus));
        let shutdown_flag = Arc::clone(&self.shutdown_flag);
        let requester = DelayedRequester::<TYPES, I> {
            network,
            state,
            public_key: self.public_key.clone(),
            sender,
            delay: self.delay,
            recipients,
            shutdown_flag,
            id: self.id,
        };

        let signature = match self.serialize_and_sign(&request) {
            Some(signature) => signature,
            None => return,
        };

        debug!("Requesting data: {:?}", request);
        let task = async_spawn(requester.run(request, signature));
        self.spawned_tasks.entry(view).or_default().push(task);
    }

    /// Raises the shared shutdown flag so that running requesters stop
    /// retrying the next time they check it.
    pub fn set_shutdown_flag(&self) {
        self.shutdown_flag.store(true, Ordering::Relaxed);
    }
}
/// State owned by a single spawned request task: everything it needs to wait
/// out the initial delay, send a request to peers in turn, and publish the
/// result as an event.
struct DelayedRequester<TYPES: NodeType, I: NodeImplementation<TYPES>> {
/// Network handle used to send the data request.
pub network: Arc<I::Network>,
/// Consensus state, read to decide whether the request is still needed.
state: OuterConsensus<TYPES>,
/// This node's public key; used to look up its own VID share in the state.
public_key: TYPES::SignatureKey,
/// Channel on which resulting events (`VidShareRecv`) are broadcast.
sender: Sender<Arc<HotShotEvent<TYPES>>>,
/// How long to wait before sending the first request.
delay: Duration,
/// Peers to ask, in the order they will be tried (cycled through
/// repeatedly until the request is cancelled).
recipients: Vec<TYPES::SignatureKey>,
/// Shared flag; when `true` the requester stops retrying.
shutdown_flag: Arc<AtomicBool>,
/// Identifier recorded in tracing spans.
id: u64,
}
/// A VID request: the view whose VID share is wanted (`.0`) and the key of
/// the requesting node (`.1`), which is also used as the message sender.
struct VidRequest<TYPES: NodeType>(TYPES::Time, TYPES::SignatureKey);
impl<TYPES: NodeType, I: NodeImplementation<TYPES>> DelayedRequester<TYPES, I> {
    /// Entry point of the spawned task.
    ///
    /// For a VID request it waits for the configured delay (skipped when the
    /// network reports its primary as down) and then runs the request loop.
    /// Proposal and DA-proposal requests are not handled here and are no-ops.
    async fn run(self, request: RequestKind<TYPES>, signature: Signature<TYPES>) {
        match request {
            RequestKind::Vid(view, key) => {
                if !self.network.is_primary_down() {
                    async_sleep(self.delay).await;
                }
                self.do_vid(VidRequest(view, key), signature).await;
            }
            RequestKind::Proposal(..) | RequestKind::DaProposal(..) => {}
        }
    }

    /// Repeatedly asks the recipients, one after another, for the VID share
    /// described by `req` until `cancel_vid` reports that the request is no
    /// longer needed.
    ///
    /// Each attempt is bounded by `REQUEST_TIMEOUT`. After a `Found` response
    /// or a send error the loop pauses for `REQUEST_TIMEOUT` before checking
    /// again; on `NotFound`, `Denied`, an undecodable response or a timeout it
    /// moves straight on to the next recipient.
    ///
    /// Returns early (with a log message) if the outgoing message cannot be
    /// serialized or if there are no recipients to ask.
    async fn do_vid(&self, req: VidRequest<TYPES>, signature: Signature<TYPES>) {
        let message = make_vid(&req, signature);
        let mut recipients_it = self.recipients.iter().cycle();
        let serialized_msg = match bincode::serialize(&message) {
            Ok(serialized_msg) => serialized_msg,
            Err(e) => {
                tracing::error!(
                    "Failed to serialize outgoing message: this should never happen. Error: {e}"
                );
                return;
            }
        };

        while !self.cancel_vid(&req).await {
            // `cycle()` over an empty list yields `None`; bail out instead of
            // panicking when the committee for this view is empty.
            let Some(recipient) = recipients_it.next() else {
                warn!("No recipients available for VID request; giving up");
                return;
            };

            match async_timeout(
                REQUEST_TIMEOUT,
                self.network
                    .request_data::<TYPES>(serialized_msg.clone(), recipient),
            )
            .await
            {
                Ok(Ok(response)) => match bincode::deserialize(&response) {
                    Ok(ResponseMessage::Found(data)) => {
                        self.handle_response_message(data).await;
                        // Give the rest of the system time to store the share
                        // (or advance the view) before the next cancel check.
                        async_sleep(REQUEST_TIMEOUT).await;
                    }
                    Ok(ResponseMessage::NotFound) => {
                        info!("Peer Responded they did not have the data");
                    }
                    Ok(ResponseMessage::Denied) => {
                        error!("Request for data was denied by the receiver");
                    }
                    Err(e) => {
                        error!("Failed to deserialize response: {e}");
                    }
                },
                Ok(Err(e)) => {
                    warn!("Error Sending request. Error: {:?}", e);
                    async_sleep(REQUEST_TIMEOUT).await;
                }
                Err(_) => {
                    warn!("Request to other node timed out");
                }
            }
        }
    }

    /// Decides whether the VID request for `req` should stop.
    ///
    /// The request is cancelled when the shutdown flag is set, when VID shares
    /// for the view are already present in the consensus state, or when the
    /// current view has moved past the requested one. On cancellation, if this
    /// node's own share for the view is present it is re-broadcast as a
    /// `VidShareRecv` event.
    ///
    /// Returns `true` if the request should be cancelled.
    #[instrument(skip_all, target = "DelayedRequester", fields(id = self.id, view = *req.0))]
    async fn cancel_vid(&self, req: &VidRequest<TYPES>) -> bool {
        let view = req.0;

        // Take everything needed out of the consensus state inside this scope
        // so the read guard is dropped before the broadcast below is awaited.
        let (cancel, cur_view, own_share) = {
            let state = self.state.read().await;
            let cancel = self.shutdown_flag.load(Ordering::Relaxed)
                || state.vid_shares().contains_key(&view)
                || state.cur_view() > view;
            let own_share = if cancel {
                state
                    .vid_shares()
                    .get(&view)
                    .and_then(|shares| shares.get(&self.public_key).cloned())
            } else {
                None
            };
            (cancel, state.cur_view(), own_share)
        };

        if cancel {
            if let Some(vid_share) = own_share {
                broadcast_event(
                    Arc::new(HotShotEvent::VidShareRecv(vid_share)),
                    &self.sender,
                )
                .await;
            }
            tracing::debug!(
                "Canceling vid request for view {:?}, cur view is {:?}",
                view,
                cur_view
            );
        }
        cancel
    }

    /// Turns a `Found` response into an event.
    ///
    /// Only VID disperse messages are expected here; they are broadcast as
    /// `VidShareRecv`. Any other message kind is silently ignored.
    async fn handle_response_message(&self, message: SequencingMessage<TYPES>) {
        let SequencingMessage::Da(DaConsensusMessage::VidDisperseMsg(prop)) = message else {
            return;
        };
        tracing::info!("vid req complete, got vid {:?}", prop);
        broadcast_event(Arc::new(HotShotEvent::VidShareRecv(prop)), &self.sender).await;
    }
}
fn make_vid<TYPES: NodeType>(
req: &VidRequest<TYPES>,
signature: Signature<TYPES>,
) -> Message<TYPES> {
let kind = RequestKind::Vid(req.0, req.1.clone());
let data_request = DataRequest {
view: req.0,
request: kind,
signature,
};
Message {
sender: req.1.clone(),
kind: MessageKind::Data(DataMessage::RequestData(data_request)),
}
}