1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Copyright (c) 2021-2024 Espresso Systems (espressosys.com)
// This file is part of the HotShot repository.

// You should have received a copy of the MIT License
// along with the HotShot repository. If not, see <https://mit-license.org/>.

use std::sync::Arc;

use async_trait::async_trait;
use hotshot_example_types::node_types::{MemoryImpl, TestTypes, TestVersions};
use hotshot_task_impls::quorum_vote::QuorumVoteTaskState;
use hotshot_types::simple_certificate::UpgradeCertificate;

use crate::predicates::{Predicate, PredicateResult};
/// Quorum vote task state instantiated with the `TestTypes` / `MemoryImpl` / `TestVersions`
/// test parameters; this is the state every predicate in this file evaluates against.
type QuorumVoteTaskTestState = QuorumVoteTaskState<TestTypes, MemoryImpl, TestVersions>;

/// Shared, thread-safe check run against a snapshot of the decided upgrade certificate;
/// returns `true` when the predicate's expectation holds.
type UpgradeCertCallback =
    Arc<dyn Send + Sync + Fn(Arc<Option<UpgradeCertificate<TestTypes>>>) -> bool>;

/// Predicate over the quorum vote task's `decided_upgrade_certificate`.
///
/// `info` is the human-readable expectation (returned from `info()` and used as the
/// `Debug` output); `check` is the callback that decides whether the observed
/// certificate satisfies that expectation.
pub struct UpgradeCertPredicate {
    info: String,
    check: UpgradeCertCallback,
}

/// Debug output is just the predicate's `info` string; the callback is not printable.
impl std::fmt::Debug for UpgradeCertPredicate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.info)
    }
}

#[async_trait]
impl Predicate<QuorumVoteTaskTestState> for UpgradeCertPredicate {
    async fn evaluate(&self, input: &QuorumVoteTaskTestState) -> PredicateResult {
        let upgrade_cert = input
            .upgrade_lock
            .decided_upgrade_certificate
            .read()
            .await
            .clone();
        PredicateResult::from((self.check)(upgrade_cert.into()))
    }

    async fn info(&self) -> String {
        self.info.clone()
    }
}

/// Builds a predicate that holds only when no upgrade certificate has been decided
/// (`decided_upgrade_certificate` is `None`).
pub fn no_decided_upgrade_certificate() -> Box<UpgradeCertPredicate> {
    let check: UpgradeCertCallback = Arc::new(|cert| cert.is_none());
    Box::new(UpgradeCertPredicate {
        check,
        info: String::from("expected decided_upgrade_certificate to be None"),
    })
}

/// Builds a predicate that holds only when an upgrade certificate has been decided
/// (`decided_upgrade_certificate` is `Some(_)`).
pub fn decided_upgrade_certificate() -> Box<UpgradeCertPredicate> {
    let check: UpgradeCertCallback = Arc::new(|cert| cert.is_some());
    Box::new(UpgradeCertPredicate {
        check,
        info: String::from("expected decided_upgrade_certificate to be Some(_)"),
    })
}