Compare commits

...

1 Commits

Author SHA1 Message Date
link2xt
3010d28901 fix: prevent reuse of the stream after an error
When a stream timeouts, `tokio_io_timeout::TimeoutStream`
returns an error once, but then allows to keep using
the stream, e.g. calling `poll_read()` again.

This can be dangerous if the error is ignored.
For example in case of IMAP stream,
if IMAP command is sent,
but then reading the response
times out and the error is ignored,
it is possible to send another IMAP command.
In this case leftover response
from a previous command may be read
and interpreted as the response
to the new IMAP command.

ErrorCapturingStream wraps the stream
to prevent its reuse after an error.
2025-07-19 13:44:00 +00:00
4 changed files with 150 additions and 8 deletions

View File

@@ -16,12 +16,14 @@ use crate::sql::Sql;
use crate::tools::time;
pub(crate) mod dns;
pub(crate) mod error_capturing_stream;
pub(crate) mod http;
pub(crate) mod proxy;
pub(crate) mod session;
pub(crate) mod tls;
use dns::lookup_host_with_cache;
pub(crate) use error_capturing_stream::ErrorCapturingStream;
pub use http::{Response as HttpResponse, read_url, read_url_blob};
use tls::wrap_tls;
@@ -105,7 +107,7 @@ pub(crate) async fn load_connection_timestamp(
/// to the network, which is important to reduce the latency of interactive protocols such as IMAP.
pub(crate) async fn connect_tcp_inner(
addr: SocketAddr,
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr))
.await
.context("connection timeout")?
@@ -118,7 +120,9 @@ pub(crate) async fn connect_tcp_inner(
timeout_stream.set_write_timeout(Some(TIMEOUT));
timeout_stream.set_read_timeout(Some(TIMEOUT));
Ok(Box::pin(timeout_stream))
let error_capturing_stream = ErrorCapturingStream::new(timeout_stream);
Ok(Box::pin(error_capturing_stream))
}
/// Attempts to establish TLS connection
@@ -235,7 +239,7 @@ pub(crate) async fn connect_tcp(
host: &str,
port: u16,
load_cache: bool,
) -> Result<Pin<Box<TimeoutStream<TcpStream>>>> {
) -> Result<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>> {
let connection_futures = lookup_host_with_cache(context, host, port, "", load_cache)
.await?
.into_iter()

View File

@@ -0,0 +1,136 @@
use std::io::IoSlice;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf};
use pin_project::pin_project;
use crate::net::SessionStream;
/// Stream that remembers the first error
/// and keeps returning it afterwards.
///
/// It is needed to avoid accidentally using
/// the stream after read timeout.
#[derive(Debug)]
#[pin_project]
pub(crate) struct ErrorCapturingStream<T: AsyncRead + AsyncWrite + std::fmt::Debug> {
#[pin]
inner: T,
/// If true, the stream has already returned an error once.
///
/// All read and write operations return error in this case.
is_broken: bool,
}
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> ErrorCapturingStream<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
is_broken: false,
}
}
/// Gets a reference to the underlying stream.
pub fn get_ref(&self) -> &T {
&self.inner
}
/// Gets a pinned mutable reference to the underlying stream.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().inner
}
}
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncRead for ErrorCapturingStream<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_read(cx, buf);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
}
impl<T: AsyncRead + AsyncWrite + std::fmt::Debug> AsyncWrite for ErrorCapturingStream<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_write(cx, buf);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_flush(cx);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_shutdown(cx);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.project();
if *this.is_broken {
return Poll::Ready(Err(io::Error::other("Broken stream")));
}
let res = this.inner.poll_write_vectored(cx, bufs);
if let Poll::Ready(Err(_)) = res {
*this.is_broken = true;
}
res
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<T: SessionStream> SessionStream for ErrorCapturingStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.inner.set_read_timeout(timeout)
}
fn peer_addr(&self) -> anyhow::Result<SocketAddr> {
self.inner.peer_addr()
}
}

View File

@@ -21,9 +21,9 @@ use url::Url;
use crate::config::Config;
use crate::constants::NON_ALPHANUMERIC_WITHOUT_DOT;
use crate::context::Context;
use crate::net::connect_tcp;
use crate::net::session::SessionStream;
use crate::net::tls::wrap_rustls;
use crate::net::{ErrorCapturingStream, connect_tcp};
use crate::sql::Sql;
/// Default SOCKS5 port according to [RFC 1928](https://tools.ietf.org/html/rfc1928).
@@ -118,7 +118,7 @@ impl Socks5Config {
target_host: &str,
target_port: u16,
load_dns_cache: bool,
) -> Result<Socks5Stream<Pin<Box<TimeoutStream<TcpStream>>>>> {
) -> Result<Socks5Stream<Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>>>> {
let tcp_stream = connect_tcp(context, &self.host, self.port, load_dns_cache)
.await
.context("Failed to connect to SOCKS5 proxy")?;

View File

@@ -7,6 +7,8 @@ use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
use tokio::net::TcpStream;
use tokio_io_timeout::TimeoutStream;
use crate::net::ErrorCapturingStream;
pub(crate) trait SessionStream:
AsyncRead + AsyncWrite + Unpin + Send + Sync + std::fmt::Debug
{
@@ -61,13 +63,13 @@ impl<T: SessionStream> SessionStream for BufWriter<T> {
self.get_ref().peer_addr()
}
}
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
impl SessionStream for Pin<Box<ErrorCapturingStream<TimeoutStream<TcpStream>>>> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.as_mut().set_read_timeout_pinned(timeout);
self.as_mut().get_pin_mut().set_read_timeout_pinned(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.get_ref().peer_addr()?)
Ok(self.get_ref().get_ref().peer_addr()?)
}
}
impl<T: SessionStream> SessionStream for Socks5Stream<T> {