feat: log the number of read/written bytes on IMAP stream read error (#6924)

This commit is contained in:
l
2025-07-17 20:01:16 +00:00
committed by GitHub
parent 6df1d165dd
commit a2df29515a
4 changed files with 237 additions and 16 deletions

View File

@@ -8,15 +8,13 @@ use tokio::io::BufWriter;
use super::capabilities::Capabilities;
use crate::context::Context;
use crate::log::{info, warn};
use crate::log::{LoggingStream, info, warn};
use crate::login_param::{ConnectionCandidate, ConnectionSecurity};
use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp};
use crate::net::proxy::ProxyConfig;
use crate::net::session::SessionStream;
use crate::net::tls::wrap_tls;
use crate::net::{
connect_tcp_inner, connect_tls_inner, run_connection_attempts, update_connection_history,
};
use crate::net::{connect_tcp_inner, run_connection_attempts, update_connection_history};
use crate::tools::time;
#[derive(Debug)]
@@ -126,12 +124,12 @@ impl Client {
);
let res = match security {
ConnectionSecurity::Tls => {
Client::connect_secure(resolved_addr, host, strict_tls).await
Client::connect_secure(context, resolved_addr, host, strict_tls).await
}
ConnectionSecurity::Starttls => {
Client::connect_starttls(resolved_addr, host, strict_tls).await
Client::connect_starttls(context, resolved_addr, host, strict_tls).await
}
ConnectionSecurity::Plain => Client::connect_insecure(resolved_addr).await,
ConnectionSecurity::Plain => Client::connect_insecure(context, resolved_addr).await,
};
match res {
Ok(client) => {
@@ -202,8 +200,17 @@ impl Client {
}
}
async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result<Self> {
let tls_stream = connect_tls_inner(addr, hostname, strict_tls, alpn(addr.port())).await?;
async fn connect_secure(
context: &Context,
addr: SocketAddr,
hostname: &str,
strict_tls: bool,
) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;
let account_id = context.get_id();
let events = context.events.clone();
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
let tls_stream = wrap_tls(strict_tls, hostname, alpn(addr.port()), logging_stream).await?;
let buffered_stream = BufWriter::new(tls_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let mut client = Client::new(session_stream);
@@ -214,9 +221,12 @@ impl Client {
Ok(client)
}
async fn connect_insecure(addr: SocketAddr) -> Result<Self> {
async fn connect_insecure(context: &Context, addr: SocketAddr) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;
let buffered_stream = BufWriter::new(tcp_stream);
let account_id = context.get_id();
let events = context.events.clone();
let logging_stream = LoggingStream::new(tcp_stream, account_id, events);
let buffered_stream = BufWriter::new(logging_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let mut client = Client::new(session_stream);
let _greeting = client
@@ -226,9 +236,18 @@ impl Client {
Ok(client)
}
async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result<Self> {
async fn connect_starttls(
context: &Context,
addr: SocketAddr,
host: &str,
strict_tls: bool,
) -> Result<Self> {
let tcp_stream = connect_tcp_inner(addr).await?;
let account_id = context.get_id();
let events = context.events.clone();
let tcp_stream = LoggingStream::new(tcp_stream, account_id, events);
// Run STARTTLS command and convert the client back into a stream.
let buffered_tcp_stream = BufWriter::new(tcp_stream);
let mut client = async_imap::Client::new(buffered_tcp_stream);
@@ -246,7 +265,6 @@ impl Client {
let tls_stream = wrap_tls(strict_tls, host, &[], tcp_stream)
.await
.context("STARTTLS upgrade failed")?;
let buffered_stream = BufWriter::new(tls_stream);
let session_stream: Box<dyn SessionStream> = Box::new(buffered_stream);
let client = Client::new(session_stream);

View File

@@ -4,6 +4,10 @@
use crate::context::Context;
mod stream;
pub(crate) use stream::LoggingStream;
macro_rules! info {
($ctx:expr, $msg:expr) => {
info!($ctx, $msg,)

160
src/log/stream.rs Normal file
View File

@@ -0,0 +1,160 @@
//! Stream that logs errors as events.
//!
//! This stream can be used to wrap IMAP,
//! SMTP and HTTP streams so errors
//! that occur are logged before
//! they are processed.
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use anyhow::Result;
use pin_project::pin_project;
use crate::events::{Event, EventType, Events};
use crate::net::session::SessionStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[derive(Debug)]
struct Metrics {
/// Total number of bytes read.
pub total_read: usize,
/// Total number of bytes written.
pub total_written: usize,
}
impl Metrics {
fn new() -> Self {
Self {
total_read: 0,
total_written: 0,
}
}
}
/// Stream that logs errors to the event channel.
#[derive(Debug)]
#[pin_project]
pub(crate) struct LoggingStream<S: SessionStream> {
#[pin]
inner: S,
/// Account ID for logging.
account_id: u32,
/// Event channel.
events: Events,
/// Metrics for this stream.
metrics: Metrics,
}
impl<S: SessionStream> LoggingStream<S> {
pub fn new(inner: S, account_id: u32, events: Events) -> Self {
Self {
inner,
account_id,
events,
metrics: Metrics::new(),
}
}
}
impl<S: SessionStream> AsyncRead for LoggingStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.project();
let peer_addr = this.inner.peer_addr();
let old_remaining = buf.remaining();
let res = this.inner.poll_read(cx, buf);
if let Poll::Ready(Err(ref err)) = res {
debug_assert!(
peer_addr.is_ok(),
"Logging stream should be created over a bound socket"
);
let log_message = match peer_addr {
Ok(peer_addr) => format!(
"Read error on stream {peer_addr:?} after reading {} and writing {} bytes: {err}.",
this.metrics.total_read, this.metrics.total_written
),
Err(_) => {
format!("Read error on a stream that does not have a peer address: {err}.")
}
};
this.events.emit(Event {
id: *this.account_id,
typ: EventType::Warning(log_message),
});
}
let n = old_remaining - buf.remaining();
this.metrics.total_read = this.metrics.total_read.saturating_add(n);
res
}
}
impl<S: SessionStream> AsyncWrite for LoggingStream<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let this = self.project();
let res = this.inner.poll_write(cx, buf);
if let Poll::Ready(Ok(n)) = res {
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
}
res
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::io::Result<()>> {
self.project().inner.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<std::io::Result<usize>> {
let this = self.project();
let res = this.inner.poll_write_vectored(cx, bufs);
if let Poll::Ready(Ok(n)) = res {
this.metrics.total_written = this.metrics.total_written.saturating_add(n);
}
res
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<S: SessionStream> SessionStream for LoggingStream<S> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.inner.set_read_timeout(timeout)
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.inner.peer_addr()
}
}

View File

@@ -1,7 +1,10 @@
use anyhow::Result;
use fast_socks5::client::Socks5Stream;
use std::net::SocketAddr;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, BufStream, BufWriter};
use tokio::net::TcpStream;
use tokio_io_timeout::TimeoutStream;
pub(crate) trait SessionStream:
@@ -9,54 +12,90 @@ pub(crate) trait SessionStream:
{
/// Change the read timeout on the session stream.
fn set_read_timeout(&mut self, timeout: Option<Duration>);
/// Returns the remote address that this stream is connected to.
fn peer_addr(&self) -> Result<SocketAddr>;
}
impl SessionStream for Box<dyn SessionStream> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.as_mut().set_read_timeout(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.as_ref().peer_addr()
}
}
impl<T: SessionStream> SessionStream for async_native_tls::TlsStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().set_read_timeout(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().peer_addr()
}
}
impl<T: SessionStream> SessionStream for tokio_rustls::client::TlsStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().0.set_read_timeout(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().0.peer_addr()
}
}
impl<T: SessionStream> SessionStream for BufStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().set_read_timeout(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().peer_addr()
}
}
impl<T: SessionStream> SessionStream for BufWriter<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().set_read_timeout(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().peer_addr()
}
}
impl<T: AsyncRead + AsyncWrite + Send + Sync + std::fmt::Debug> SessionStream
for Pin<Box<TimeoutStream<T>>>
{
impl SessionStream for Pin<Box<TimeoutStream<TcpStream>>> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.as_mut().set_read_timeout_pinned(timeout);
}
fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.get_ref().peer_addr()?)
}
}
impl<T: SessionStream> SessionStream for Socks5Stream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_socket_mut().set_read_timeout(timeout)
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_socket_ref().peer_addr()
}
}
impl<T: SessionStream> SessionStream for shadowsocks::ProxyClientStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().set_read_timeout(timeout)
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().peer_addr()
}
}
impl<T: SessionStream> SessionStream for async_imap::DeflateStream<T> {
fn set_read_timeout(&mut self, timeout: Option<Duration>) {
self.get_mut().set_read_timeout(timeout)
}
fn peer_addr(&self) -> Result<SocketAddr> {
self.get_ref().peer_addr()
}
}
/// Session stream with a read buffer.