mirror of
https://github.com/chatmail/core.git
synced 2026-05-05 06:16:30 +03:00
fix: prevent reuse of the stream after an error
When a stream timeouts, `tokio_io_timeout::TimeoutStream` returns an error once, but then allows to keep using the stream, e.g. calling `poll_read()` again. This can be dangerous if the error is ignored. For example in case of IMAP stream, if IMAP command is sent, but then reading the response times out and the error is ignored, it is possible to send another IMAP command. In this case leftover response from a previous command may be read and interpreted as the response to the new IMAP command. ErrorCapturingStream wraps the stream to prevent its reuse after an error.
This commit is contained in:
10
src/net.rs
10
src/net.rs
@@ -16,12 +16,14 @@ use crate::sql::Sql;
|
||||
use crate::tools::time;
|
||||
|
||||
pub(crate) mod dns;
|
||||
pub(crate) mod error_capturing_stream;
|
||||
pub(crate) mod http;
|
||||
pub(crate) mod proxy;
|
||||
pub(crate) mod session;
|
||||
pub(crate) mod tls;
|
||||
|
||||
use dns::lookup_host_with_cache;
|
||||
pub(crate) use error_capturing_stream::ErrorCapturingStream;
|
||||
pub use http::{Response as HttpResponse, read_url, read_url_blob};
|
||||
use tls::wrap_tls;
|
||||
|
||||
@@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp(
|
||||
/// to the network, which is important to reduce the latency of interactive protocols such as IMAP.
|
||||
pub(crate) async fn connect_tcp_inner(
|
||||
addr: SocketAddr,
|
||||
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
|
||||
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
|
||||
let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr))
|
||||
.await
|
||||
.context("connection timeout")?
|
||||
@@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner(
|
||||
timeout_stream.set_write_timeout(Some(TIMEOUT));
|
||||
timeout_stream.set_read_timeout(Some(TIMEOUT));
|
||||
|
||||
Ok(Box::pin(timeout_stream))
|
||||
let error_capturing_stream = ErrorCapturingStream::new(timeout_stream);
|
||||
|
||||
Ok(Box::pin(error_capturing_stream))
|
||||
}
|
||||
|
||||
/// Attempts to establish TLS connection
|
||||
@@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp(
|
||||
host: &str,
|
||||
port: u16,
|
||||
load_cache: bool,
|
||||
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
|
||||
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
|
||||
let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache)
|
||||
.await?
|
||||
.into_iter()
|
||||
|
||||
136
src/net/error_capturing_stream.rs
Normal file
136
src/net/error_capturing_stream.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
use std::io::IoSlice;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
use crate::net::SessionStream;
|
||||
|
||||
/// Stream that remembers the first error
|
||||
/// and keeps returning it afterwards.
|
||||
///
|
||||
/// It is needed to avoid accidentally using
|
||||
/// the stream after read timeout.
|
||||
#[derive(Debug)]
|
||||
#[pin_project]
|
||||
pub(crate) struct ErrorCapturingStream<T: AsyncRead + AsyncWrite + std::fmt::Debug> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
|
||||
/// If true, the stream has already returned an error once.
|
||||
///
|
||||
/// All read and write operations return error in this case.
|
||||
is_broken: bool,
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> ErrorCapturingStream<T> {
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
is_broken: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets a reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &T {
|
||||
&self.inner
|
||||
}
|
||||
|
||||
/// Gets a pinned mutable reference to the underlying stream.
|
||||
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
|
||||
self.project().inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncRead for ErrorCapturingStream<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
if *this.is_broken {
|
||||
return Poll::Ready(Err(io::Error::other("Broken stream")));
|
||||
}
|
||||
let res = this.inner.poll_read(cx, buf);
|
||||
if let Poll::Ready(Err(_)) = res {
|
||||
*this.is_broken = true;
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncWrite for ErrorCapturingStream<T> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
if *this.is_broken {
|
||||
return Poll::Ready(Err(io::Error::other("Broken stream")));
|
||||
}
|
||||
let res = this.inner.poll_write(cx, buf);
|
||||
if let Poll::Ready(Err(_)) = res {
|
||||
*this.is_broken = true;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
if *this.is_broken {
|
||||
return Poll::Ready(Err(io::Error::other("Broken stream")));
|
||||
}
|
||||
let res = this.inner.poll_flush(cx);
|
||||
if let Poll::Ready(Err(_)) = res {
|
||||
*this.is_broken = true;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
if *this.is_broken {
|
||||
return Poll::Ready(Err(io::Error::other("Broken stream")));
|
||||
}
|
||||
let res = this.inner.poll_shutdown(cx);
|
||||
if let Poll::Ready(Err(_)) = res {
|
||||
*this.is_broken = true;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
if *this.is_broken {
|
||||
return Poll::Ready(Err(io::Error::other("Broken stream")));
|
||||
}
|
||||
let res = this.inner.poll_write_vectored(cx, bufs);
|
||||
if let Poll::Ready(Err(_)) = res {
|
||||
*this.is_broken = true;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn is_write_vectored(&self) -> bool {
|
||||
self.inner.is_write_vectored()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: SessionStream> SessionStream for ErrorCapturingStream<T> {
|
||||
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
|
||||
self.inner.set_read_timeout(timeout)
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> anyhow::Result<SocketAddr> {
|
||||
self.inner.peer_addr()
|
||||
}
|
||||
}
|
||||
@@ -21,9 +21,9 @@ use url::Url;
|
||||
use crate::config::Config;
|
||||
use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT;
|
||||
use crate::context::Context;
|
||||
use crate::net::connect_tcp;
|
||||
use crate::net::session::SessionStream;
|
||||
use crate::net::tls::wrap_rustls;
|
||||
use crate::net::{ErrorCapturingStream, connect_tcp};
|
||||
use crate::sql::Sql;
|
||||
|
||||
/// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928).
|
||||
@@ -118,7 +118,7 @@ impl Socks5Config {
|
||||
target_host: &str,
|
||||
target_port: u16,
|
||||
load_dns_cache: bool,
|
||||
) -> Result<Socks5Stream<Pin<Box<TimeoutStream<TcpStream>>>>> {
|
||||
) -> Result<Socks5Stream<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>>> {
|
||||
let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache)
|
||||
.await
|
||||
.context("Failed to connect to SOCKS5 proxy")?;
|
||||
|
||||
@@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_io_timeout::TimeoutStream;
|
||||
|
||||
use crate::net::ErrorCapturingStream;
|
||||
|
||||
pub(crate) trait SessionStream:
|
||||
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
|
||||
{
|
||||
@@ -61,13 +63,13 @@ impl<T: SessionStream> SessionStream for BufWriter<T> {
|
||||
self.get_ref().peer_addr()
|
||||
}
|
||||
}
|
||||
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
|
||||
impl SessionStream for Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>> {
|
||||
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
|
||||
self.as_mut().set_read_timeout_pinned(timeout);
|
||||
self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout);
|
||||
}
|
||||
|
||||
fn peer_addr(&self) -> Result<SocketAddr> {
|
||||
Ok(self.get_ref().peer_addr()?)
|
||||
Ok(self.get_ref().get_ref().peer_addr()?)
|
||||
}
|
||||
}
|
||||
impl<T: SessionStream> SessionStream for Socks5Stream<T> {
|
||||
|
||||
Reference in New Issue
Block a user