use std::cmp;
use std::fmt::Debug;
use std::future::Future;
use std::io::{Error as IoError, SeekFrom};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::{Buf, BufMut};
use err_context::AnyError;
use log::trace;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use spirit::extension::Extensible;
use spirit::fragment::driver::{CacheSimilar, Comparable, Comparison};
use spirit::fragment::{Fragment, Stackable};
#[cfg(feature = "cfg-help")]
use structdoc::StructDoc;
use structopt::StructOpt;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite};
#[cfg(feature = "stream")]
use tokio::stream::Stream;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::{self, Delay};
use super::Accept;
/// Source of the tunables applied to a listening socket.
///
/// [`WithListenLimits`] queries both values when it builds a [`Limited`] acceptor.
pub trait ListenLimits {
/// How long accepting is paused after an accept error that is not tied to a single connection.
fn error_sleep(&self) -> Duration;
/// Upper bound on the number of accepted connections that may be alive at the same time.
fn max_conn(&self) -> usize;
}
/// Bundles a listener configuration with the limits applied to it.
///
/// Both parts are flattened, so their fields sit on the same level of the configuration.
#[derive(
    Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
)]
// `StructDoc` is only imported when the `cfg-help` feature is enabled, so the derive has to be
// gated the same way; deriving it unconditionally breaks builds without that feature.
#[cfg_attr(feature = "cfg-help", derive(StructDoc))]
#[non_exhaustive]
pub struct WithListenLimits<A, L> {
    /// Configuration of the wrapped listener.
    #[serde(flatten)]
    pub listen: A,
    /// Limits enforced on the connections accepted from `listen`.
    #[serde(flatten)]
    pub limits: L,
}
/// [`WithListenLimits`] with the limits taken from the [`Limits`] structure.
pub type WithLimits<A> = WithListenLimits<A, Limits>;
/// Marker impl: the wrapper is `Stackable` whenever the wrapped listener configuration is.
impl<A, L> Stackable for WithListenLimits<A, L>
where
    A: Stackable,
{
}
/// Compares two configurations, taking both the listener and the limits into account.
impl<A, L> Comparable for WithListenLimits<A, L>
where
    A: Comparable,
    L: PartialEq,
{
    /// Returns the listener's own verdict, downgraded from `Same` to `Similar` when only the
    /// limits differ.
    fn compare(&self, other: &Self) -> Comparison {
        match self.listen.compare(&other.listen) {
            // Identical listener but changed limits: not the same, yet closely related.
            Comparison::Same if self.limits != other.limits => Comparison::Similar,
            verdict => verdict,
        }
    }
}
/// Turns the combined configuration into a connection-limited acceptor.
///
/// Seed creation and builder initialization are delegated to the wrapped listener fragment;
/// only the resource is decorated, by wrapping it into [`Limited`].
impl<A, L> Fragment for WithListenLimits<A, L>
where
    A: Clone + Debug + Fragment + Comparable,
    L: Clone + Debug + ListenLimits + PartialEq,
{
    type Driver = CacheSimilar<Self>;
    type Installer = ();
    type Seed = A::Seed;
    type Resource = Limited<A::Resource>;
    const RUN_BEFORE_CONFIG: bool = A::RUN_BEFORE_CONFIG;

    /// Creates the seed of the wrapped listener.
    fn make_seed(&self, name: &'static str) -> Result<Self::Seed, AnyError> {
        self.listen.make_seed(name)
    }

    /// Builds the wrapped listener's resource and wraps it with the configured limits.
    fn make_resource(
        &self,
        seed: &mut Self::Seed,
        name: &'static str,
    ) -> Result<Self::Resource, AnyError> {
        self.listen.make_resource(seed, name).map(|inner| {
            // The requested limit is capped before it is handed to the semaphore.
            let permit_cap = usize::MAX >> 4;
            let permits = cmp::min(self.limits.max_conn(), permit_cap);
            let error_sleep = self.limits.error_sleep();
            let allowed_conns = Arc::new(Semaphore::new(permits));
            Limited {
                inner,
                error_sleep,
                err_delay: None,
                allowed_conns,
                permit_fut: None,
            }
        })
    }

    /// Lets the wrapped listener fragment initialize the builder.
    fn init<B: Extensible<Ok = B>>(builder: B, name: &'static str) -> Result<B, AnyError>
    where
        B::Config: DeserializeOwned + Send + Sync + 'static,
        B::Opts: StructOpt + Send + Sync + 'static,
    {
        A::init(builder, name)
    }
}
/// Default back-off after an accept error: 100 milliseconds.
fn default_error_sleep() -> Duration {
    const DEFAULT_SLEEP_MS: u64 = 100;
    Duration::from_millis(DEFAULT_SLEEP_MS)
}
/// Listener limits loaded from configuration; the stock [`ListenLimits`] implementation.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize)]
#[cfg_attr(feature = "cfg-help", derive(StructDoc))]
#[non_exhaustive]
pub struct Limits {
    /// How long to pause accepting after an accept error that is not tied to one connection.
    ///
    /// Configured under the key `error-sleep`; a missing key yields 100 milliseconds.
    #[serde(
        rename = "error-sleep",
        default = "default_error_sleep",
        deserialize_with = "spirit::utils::deserialize_duration",
        serialize_with = "spirit::utils::serialize_duration"
    )]
    pub error_sleep: Duration,
    /// Maximum number of connections alive at once; `None` means no configured limit.
    ///
    /// The key is `max_conn`. The kebab-case spelling `max-conn` is additionally accepted when
    /// deserializing, so the field can be written in the same style as `error-sleep`; the
    /// serialized name is unchanged.
    #[serde(alias = "max-conn", skip_serializing_if = "Option::is_none")]
    pub max_conn: Option<usize>,
}
impl Default for Limits {
fn default() -> Self {
Self {
error_sleep: default_error_sleep(),
max_conn: None,
}
}
}
impl ListenLimits for Limits {
fn error_sleep(&self) -> Duration {
self.error_sleep
}
fn max_conn(&self) -> usize {
self.max_conn.unwrap_or_else(|| usize::max_value() / 2 - 1)
}
}
/// Acceptor wrapper that enforces the connection limit and the back-off after accept errors.
pub struct Limited<A> {
/// The wrapped acceptor that produces the actual connections.
inner: A,
/// How long accepting is paused after an accept error that is not tied to one connection.
error_sleep: Duration,
/// The back-off currently in progress, if any.
err_delay: Option<Delay>,
/// Semaphore holding one permit per connection that may still be accepted.
allowed_conns: Arc<Semaphore>,
/// Pending acquisition of a permit, kept between polls when none was available right away.
permit_fut: Option<Pin<Box<dyn Future<Output = OwnedSemaphorePermit> + Send + Sync>>>,
}
/// Accepts connections from the wrapped acceptor while enforcing the configured limits.
impl<A: Accept> Accept for Limited<A> {
    type Connection = Tracked<A::Connection>;

    /// Polls for the next connection.
    ///
    /// A permit from the connection semaphore is obtained first, so nothing is accepted while
    /// all permits are held by live connections. Errors tied to a single connection attempt
    /// are skipped; any other accept error pauses accepting for `error_sleep` and is then
    /// retried, so this method never resolves to an `Err`.
    fn poll_accept(&mut self, ctx: &mut Context) -> Poll<Result<Self::Connection, IoError>> {
        /// Errors that concern only the connection being accepted, not the listener itself.
        fn is_connection_local(e: &IoError) -> bool {
            use std::io::ErrorKind::*;
            matches!(
                e.kind(),
                ConnectionAborted | ConnectionRefused | ConnectionReset
            )
        }

        // If we return `Pending` further down, the permit is dropped with this local and
        // acquired again on the next poll.
        let permit = loop {
            match self.permit_fut.as_mut() {
                Some(fut) => {
                    let polled = fut.as_mut().poll(ctx);
                    match polled {
                        Poll::Ready(permit) => {
                            // A completed future must not be polled again. Forget it so the
                            // next call starts from a fresh acquisition instead of re-polling.
                            self.permit_fut = None;
                            break permit;
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
                None => match Arc::clone(&self.allowed_conns).try_acquire_owned() {
                    Ok(permit) => break permit,
                    Err(_) => {
                        // No free permit. Store the waiting future; the next loop turn polls
                        // it, which registers the waker.
                        let permit_fut = Arc::clone(&self.allowed_conns).acquire_owned();
                        self.permit_fut = Some(Box::pin(permit_fut));
                    }
                },
            }
        };

        loop {
            // Honor a back-off that is still in progress.
            if let Some(delay) = self.err_delay.as_mut() {
                if Pin::new(delay).poll(ctx).is_ready() {
                    self.err_delay.take();
                } else {
                    return Poll::Pending;
                }
            }

            match self.inner.poll_accept(ctx) {
                Poll::Ready(Err(ref e)) if is_connection_local(e) => {
                    // Only that one connection attempt failed; try the next one right away.
                    trace!("Connection attempt error: {}", e);
                    continue;
                }
                Poll::Ready(Err(_)) => {
                    // Likely a problem of the listener itself; back off before retrying. The
                    // next loop turn polls the delay.
                    trace!("Accept error, sleeping for {:?}", self.error_sleep);
                    self.err_delay = Some(time::delay_for(self.error_sleep));
                }
                Poll::Ready(Ok(conn)) => {
                    trace!("Got a new connection");
                    return Poll::Ready(Ok(Tracked {
                        inner: conn,
                        _permit: permit,
                    }));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Presents the acceptor as a stream of accept results; the stream never yields `None`.
#[cfg(feature = "stream")]
impl<A> Stream for Limited<A>
where
    A: Accept + Unpin,
{
    type Item = Result<<Self as Accept>::Connection, IoError>;

    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
        let acceptor = self.get_mut();
        match acceptor.poll_accept(ctx) {
            Poll::Ready(outcome) => Poll::Ready(Some(outcome)),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// An accepted connection bundled with the semaphore permit obtained for it.
pub struct Tracked<C> {
/// The wrapped connection.
inner: C,
/// Permit held for as long as the connection exists; never read, only kept alive.
_permit: OwnedSemaphorePermit,
}
impl<C> Deref for Tracked<C> {
type Target = C;
fn deref(&self) -> &C {
&self.inner
}
}
impl<C> DerefMut for Tracked<C> {
fn deref_mut(&mut self) -> &mut C {
&mut self.inner
}
}
/// Forwards buffered reading to the wrapped connection.
impl<C: AsyncBufRead + Unpin> AsyncBufRead for Tracked<C> {
    fn poll_fill_buf(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<&[u8], IoError>> {
        // Unpin the wrapper first so the returned slice can borrow for the full lifetime.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_fill_buf(ctx)
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let this = self.get_mut();
        Pin::new(&mut this.inner).consume(amt)
    }
}
/// Forwards reading to the wrapped connection.
impl<C: AsyncRead + Unpin> AsyncRead for Tracked<C> {
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(ctx, buf)
    }

    fn poll_read_buf<B: BufMut>(
        self: Pin<&mut Self>,
        ctx: &mut Context,
        buf: &mut B,
    ) -> Poll<Result<usize, IoError>>
    where
        Self: Sized,
    {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read_buf(ctx, buf)
    }
}
/// Forwards seeking to the wrapped connection.
impl<C: AsyncSeek + Unpin> AsyncSeek for Tracked<C> {
    fn start_seek(
        self: Pin<&mut Self>,
        ctx: &mut Context,
        position: SeekFrom,
    ) -> Poll<Result<(), IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).start_seek(ctx, position)
    }

    fn poll_complete(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result<u64, IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_complete(ctx)
    }
}
/// Forwards writing, flushing and shutdown to the wrapped connection.
impl<C: AsyncWrite + Unpin> AsyncWrite for Tracked<C> {
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(ctx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result<(), IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(ctx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Result<(), IoError>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(ctx)
    }

    fn poll_write_buf<B: Buf>(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut B,
    ) -> Poll<Result<usize, IoError>>
    where
        Self: Sized,
    {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write_buf(ctx, buf)
    }
}
// Tests of the connection limit, run against a real TCP listener on localhost.
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use spirit::Empty;
use tokio::net::TcpStream;
use tokio::sync::oneshot::{self, Receiver};
use tokio::time;
use super::*;
use crate::net::{Listen, TcpListen};
/// Builds a localhost listener limited to two connections and returns it with its address.
fn listener() -> (Limited<impl Accept>, SocketAddr) {
let incoming_cfg = WithListenLimits {
listen: TcpListen {
listen: Listen {
host: IpAddr::V4(Ipv4Addr::LOCALHOST),
// NOTE(review): presumably the default port lets the OS pick a free one; confirm.
..Listen::default()
},
tcp_config: Empty {},
extra_cfg: Empty {},
},
limits: Limits {
error_sleep: Duration::from_millis(100),
max_conn: Some(2),
},
};
let mut seed = incoming_cfg.make_seed("test_listener").unwrap();
let addr = seed.local_addr().unwrap();
let listener = incoming_cfg
.make_resource(&mut seed, "test_listener")
.unwrap();
// The configured limit must end up as the number of permits.
assert_eq!(2, listener.allowed_conns.available_permits());
(listener, addr)
}
/// Opens three connections to `addr` and keeps them open until `done` fires.
async fn connector(addr: SocketAddr, done: Receiver<()>) {
let _conn1 = TcpStream::connect(addr).await.unwrap();
let _conn2 = TcpStream::connect(addr).await.unwrap();
let _conn3 = TcpStream::connect(addr).await.unwrap();
done.await.unwrap();
}
/// A third connection is not accepted until one of the first two is dropped.
#[tokio::test]
async fn conn_limit() {
let (mut listener, addr) = listener();
let (done_send, done_recv) = oneshot::channel();
let connector = tokio::spawn(connector(addr, done_recv));
let acceptor = tokio::spawn(async move {
let conn1 = listener.accept().await.unwrap();
let _conn2 = listener.accept().await.unwrap();
// Both permits are now held by the accepted connections.
assert_eq!(0, listener.allowed_conns.available_permits());
let over_limit = listener.accept();
let over_limit = time::timeout(Duration::from_millis(50), over_limit);
assert!(over_limit.await.is_err(), "Accepted extra connection");
// Dropping a connection frees its permit, so the third accept can go through.
drop(conn1);
let _conn3 = listener.accept().await.unwrap();
done_send.send(()).unwrap();
});
time::timeout(Duration::from_secs(5), async {
acceptor.await.unwrap();
connector.await.unwrap();
})
.await
.expect("Didn't finish test in time");
}
/// Accepting continues once connections are dropped in the background.
#[tokio::test]
async fn conn_limit_cont() {
let (mut listener, addr) = listener();
let (done_send, done_recv) = oneshot::channel();
let connector = tokio::spawn(connector(addr, done_recv));
let acceptor = tokio::spawn(async move {
// Three accepts against a limit of two: the last one has to wait for a freed permit.
for _ in 0..3 {
let conn = listener.accept().await.unwrap();
// Each connection is dropped by its own task after 50 ms.
tokio::spawn(async move {
time::delay_for(Duration::from_millis(50)).await;
drop(conn);
});
}
done_send.send(()).unwrap();
});
time::timeout(Duration::from_secs(5), async {
acceptor.await.unwrap();
connector.await.unwrap();
})
.await
.expect("Didn't finish test in time");
}
}