use std::{collections::TryReserveError, convert::Infallible, io, marker::PhantomData};
use async_trait::async_trait;
use cbor4ii::core::error::DecodeError;
use futures::prelude::*;
use libp2p::{
request_response::{self, Codec},
StreamProtocol,
};
use serde::{de::DeserializeOwned, Serialize};
pub type Behaviour<Req, Resp> = request_response::Behaviour<Cbor<Req, Resp>>;
/// Request/response codec that serializes messages as CBOR using `cbor4ii`.
///
/// `Req` and `Resp` are only tracked at the type level; the codec itself stores
/// nothing but the byte limits applied when reading inbound messages.
pub struct Cbor<Req, Resp> {
    phantom: PhantomData<(Req, Resp)>,
    /// Maximum number of bytes read for a single inbound request.
    request_size_maximum: u64,
    /// Maximum number of bytes read for a single inbound response.
    response_size_maximum: u64,
}

impl<Req, Resp> Cbor<Req, Resp> {
    /// Limit used by [`Default`] for both requests and responses (20 MiB).
    const DEFAULT_SIZE_MAXIMUM: u64 = 20 * 1024 * 1024;

    /// Creates a codec with explicit byte limits for inbound requests and
    /// inbound responses.
    #[must_use]
    pub fn new(request_size_maximum: u64, response_size_maximum: u64) -> Self {
        Self {
            phantom: PhantomData,
            request_size_maximum,
            response_size_maximum,
        }
    }
}

impl<Req, Resp> Default for Cbor<Req, Resp> {
    /// Creates a codec with a 20 MiB limit for both requests and responses.
    fn default() -> Self {
        Self::new(Self::DEFAULT_SIZE_MAXIMUM, Self::DEFAULT_SIZE_MAXIMUM)
    }
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add
// `Req: Clone` and `Resp: Clone` bounds, although neither type is stored.
impl<Req, Resp> Clone for Cbor<Req, Resp> {
    fn clone(&self) -> Self {
        // Copy the configured limits instead of falling back to `Default`, so a
        // codec built with `Cbor::new` keeps its custom limits when cloned.
        Self::new(self.request_size_maximum, self.response_size_maximum)
    }
}
#[async_trait]
impl<Req, Resp> Codec for Cbor<Req, Resp>
where
    Req: Send + Serialize + DeserializeOwned,
    Resp: Send + Serialize + DeserializeOwned,
{
    type Protocol = StreamProtocol;
    type Request = Req;
    type Response = Resp;

    /// Reads the inbound request (at most `request_size_maximum` bytes, until
    /// EOF) and decodes it from CBOR.
    ///
    /// Bytes beyond the limit are never read, so an oversized request arrives
    /// truncated and normally fails to decode rather than being rejected up front.
    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Req>
    where
        T: AsyncRead + Unpin + Send,
    {
        let bytes = read_bounded(io, self.request_size_maximum).await?;
        cbor4ii::serde::from_slice(&bytes).map_err(decode_into_io_error)
    }

    /// Reads the inbound response (at most `response_size_maximum` bytes, until
    /// EOF) and decodes it from CBOR. Oversized input is truncated, not rejected.
    async fn read_response<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Resp>
    where
        T: AsyncRead + Unpin + Send,
    {
        let bytes = read_bounded(io, self.response_size_maximum).await?;
        cbor4ii::serde::from_slice(&bytes).map_err(decode_into_io_error)
    }

    /// Encodes `req` as CBOR and writes every encoded byte to `io`.
    async fn write_request<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        req: Self::Request,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        let encoded = encode_cbor(&req)?;
        io.write_all(&encoded).await?;
        Ok(())
    }

    /// Encodes `resp` as CBOR and writes every encoded byte to `io`.
    async fn write_response<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        resp: Self::Response,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        let encoded = encode_cbor(&resp)?;
        io.write_all(&encoded).await?;
        Ok(())
    }
}

/// Reads from `io` until EOF or until `limit` bytes have been consumed,
/// whichever comes first, and returns the bytes read.
async fn read_bounded<T>(io: &mut T, limit: u64) -> io::Result<Vec<u8>>
where
    T: AsyncRead + Unpin,
{
    let mut buffer = Vec::new();
    let mut limited = io.take(limit);
    limited.read_to_end(&mut buffer).await?;
    Ok(buffer)
}

/// Serializes `message` into a fresh CBOR byte vector, mapping encoder
/// failures to [`io::Error`].
fn encode_cbor<M: Serialize>(message: &M) -> io::Result<Vec<u8>> {
    cbor4ii::serde::to_vec(Vec::new(), message).map_err(encode_into_io_error)
}
fn decode_into_io_error(err: cbor4ii::serde::DecodeError<Infallible>) -> io::Error {
match err {
cbor4ii::serde::DecodeError::Core(DecodeError::Read(e)) => {
io::Error::new(io::ErrorKind::Other, e)
}
cbor4ii::serde::DecodeError::Core(e @ DecodeError::Unsupported { .. }) => {
io::Error::new(io::ErrorKind::Unsupported, e)
}
cbor4ii::serde::DecodeError::Core(e @ DecodeError::Eof { .. }) => {
io::Error::new(io::ErrorKind::UnexpectedEof, e)
}
cbor4ii::serde::DecodeError::Core(e) => io::Error::new(io::ErrorKind::InvalidData, e),
cbor4ii::serde::DecodeError::Custom(e) => {
io::Error::new(io::ErrorKind::Other, e.to_string())
}
}
}
/// Wraps a `cbor4ii` serde encoding failure in an [`io::Error`] of kind
/// [`io::ErrorKind::Other`]. The original error stays retrievable through
/// [`io::Error::get_ref`] / [`io::Error::into_inner`].
fn encode_into_io_error(err: cbor4ii::serde::EncodeError<TryReserveError>) -> io::Error {
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, err)
}