From ef925b09489cab4921584ee7932b0d8d5db98cad Mon Sep 17 00:00:00 2001 From: link2xt Date: Sun, 28 Jul 2024 21:03:50 +0000 Subject: [PATCH] refactor: move DNS resolution into IMAP and SMTP connect code --- src/imap/client.rs | 90 +++++++++++++++++++---------- src/net.rs | 135 ++------------------------------------------ src/smtp/connect.rs | 84 +++++++++++++++++---------- 3 files changed, 119 insertions(+), 190 deletions(-) diff --git a/src/imap/client.rs b/src/imap/client.rs index 97a31a5c6..6f967fd39 100644 --- a/src/imap/client.rs +++ b/src/imap/client.rs @@ -1,6 +1,7 @@ +use std::net::SocketAddr; use std::ops::{Deref, DerefMut}; -use anyhow::{bail, Context as _, Result}; +use anyhow::{bail, format_err, Context as _, Result}; use async_imap::Client as ImapClient; use async_imap::Session as ImapSession; use tokio::io::BufWriter; @@ -8,9 +9,10 @@ use tokio::io::BufWriter; use super::capabilities::Capabilities; use super::session::Session; use crate::context::Context; +use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp}; use crate::net::session::SessionStream; use crate::net::tls::wrap_tls; -use crate::net::{connect_starttls_imap, connect_tcp, connect_tls}; +use crate::net::{connect_tcp_inner, connect_tls_inner}; use crate::provider::Socket; use crate::socks::Socks5Config; use fast_socks5::client::Socks5Stream; @@ -102,37 +104,54 @@ impl Client { security: Socket, ) -> Result { if let Some(socks5_config) = socks5_config { - match security { + let client = match security { Socket::Automatic => bail!("IMAP port security is not configured"), Socket::Ssl => { Client::connect_secure_socks5(context, host, port, strict_tls, socks5_config) - .await + .await? } Socket::Starttls => { Client::connect_starttls_socks5(context, host, port, socks5_config, strict_tls) - .await + .await? } Socket::Plain => { - Client::connect_insecure_socks5(context, host, port, socks5_config).await + Client::connect_insecure_socks5(context, host, port, socks5_config).await? + } + }; + Ok(client) + } else { + let mut first_error = None; + let load_cache = + strict_tls && (security == Socket::Ssl || security == Socket::Starttls); + for resolved_addr in lookup_host_with_cache(context, host, port, load_cache).await? { + let res = match security { + Socket::Automatic => bail!("IMAP port security is not configured"), + Socket::Ssl => Client::connect_secure(resolved_addr, host, strict_tls).await, + Socket::Starttls => { + Client::connect_starttls(resolved_addr, host, strict_tls).await + } + Socket::Plain => Client::connect_insecure(resolved_addr).await, + }; + match res { + Ok(client) => { + let ip_addr = resolved_addr.ip().to_string(); + if load_cache { + update_connect_timestamp(context, host, &ip_addr).await?; + } + return Ok(client); + } + Err(err) => { + warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); + first_error.get_or_insert(err); + } } } - } else { - match security { - Socket::Automatic => bail!("IMAP port security is not configured"), - Socket::Ssl => Client::connect_secure(context, host, port, strict_tls).await, - Socket::Starttls => Client::connect_starttls(context, host, port, strict_tls).await, - Socket::Plain => Client::connect_insecure(context, host, port).await, - } + Err(first_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) } } - async fn connect_secure( - context: &Context, - hostname: &str, - port: u16, - strict_tls: bool, - ) -> Result { - let tls_stream = connect_tls(context, hostname, port, strict_tls, "imap").await?; + async fn connect_secure(addr: SocketAddr, hostname: &str, strict_tls: bool) -> Result { + let tls_stream = connect_tls_inner(addr, hostname, strict_tls, "imap").await?; let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let mut client = Client::new(session_stream); @@ -143,8 +162,8 @@ impl Client { Ok(client) } - async fn connect_insecure(context: &Context, hostname: &str, port: u16) -> Result { - let tcp_stream = connect_tcp(context, hostname, port, false).await?; + async fn connect_insecure(addr: SocketAddr) -> Result { + let tcp_stream = connect_tcp_inner(addr).await?; let buffered_stream = BufWriter::new(tcp_stream); let session_stream: Box = Box::new(buffered_stream); let mut client = Client::new(session_stream); @@ -155,13 +174,26 @@ impl Client { Ok(client) } - async fn connect_starttls( - context: &Context, - hostname: &str, - port: u16, - strict_tls: bool, - ) -> Result { - let tls_stream = connect_starttls_imap(context, hostname, port, strict_tls).await?; + async fn connect_starttls(addr: SocketAddr, host: &str, strict_tls: bool) -> Result { + let tcp_stream = connect_tcp_inner(addr).await?; + + // Run STARTTLS command and convert the client back into a stream. + let buffered_tcp_stream = BufWriter::new(tcp_stream); + let mut client = async_imap::Client::new(buffered_tcp_stream); + let _greeting = client + .read_response() + .await + .context("failed to read greeting")??; + client + .run_command_and_check_ok("STARTTLS", None) + .await + .context("STARTTLS command failed")?; + let buffered_tcp_stream = client.into_inner(); + let tcp_stream = buffered_tcp_stream.into_inner(); + + let tls_stream = wrap_tls(strict_tls, host, "imap", tcp_stream) + .await + .context("STARTTLS upgrade failed")?; let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); diff --git a/src/net.rs b/src/net.rs index 535f0ba7b..664e1b466 100644 --- a/src/net.rs +++ b/src/net.rs @@ -5,8 +5,6 @@ use std::time::Duration; use anyhow::{format_err, Context as _, Result}; use async_native_tls::TlsStream; -use tokio::io::BufStream; -use tokio::io::BufWriter; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_io_timeout::TimeoutStream; @@ -32,7 +30,9 @@ pub(crate) const TIMEOUT: Duration = Duration::from_secs(60); /// /// `TCP_NODELAY` ensures writing to the stream always results in immediate sending of the packet /// to the network, which is important to reduce the latency of interactive protocols such as IMAP. -async fn connect_tcp_inner(addr: SocketAddr) -> Result>>> { +pub(crate) async fn connect_tcp_inner( + addr: SocketAddr, +) -> Result>>> { let tcp_stream = timeout(TIMEOUT, TcpStream::connect(addr)) .await .context("connection timeout")? @@ -50,7 +50,7 @@ async fn connect_tcp_inner(addr: SocketAddr) -> Result Result>>>> { - let mut first_error = None; - - for resolved_addr in lookup_host_with_cache(context, host, port, strict_tls).await? { - match connect_tls_inner(resolved_addr, host, strict_tls, alpn).await { - Ok(tls_stream) => { - if strict_tls { - dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) - .await?; - } - return Ok(tls_stream); - } - Err(err) => { - warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); - first_error.get_or_insert(err); - } - } - } - - Err(first_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) -} - -async fn connect_starttls_imap_inner( - addr: SocketAddr, - host: &str, - strict_tls: bool, -) -> Result>>>> { - let tcp_stream = connect_tcp_inner(addr).await?; - - // Run STARTTLS command and convert the client back into a stream. - let buffered_tcp_stream = BufWriter::new(tcp_stream); - let mut client = async_imap::Client::new(buffered_tcp_stream); - let _greeting = client - .read_response() - .await - .context("failed to read greeting")??; - client - .run_command_and_check_ok("STARTTLS", None) - .await - .context("STARTTLS command failed")?; - let buffered_tcp_stream = client.into_inner(); - let tcp_stream = buffered_tcp_stream.into_inner(); - - let tls_stream = wrap_tls(strict_tls, host, "imap", tcp_stream) - .await - .context("STARTTLS upgrade failed")?; - - Ok(tls_stream) -} - -pub(crate) async fn connect_starttls_imap( - context: &Context, - host: &str, - port: u16, - strict_tls: bool, -) -> Result>>>> { - let mut first_error = None; - - for resolved_addr in lookup_host_with_cache(context, host, port, strict_tls).await? { - match connect_starttls_imap_inner(resolved_addr, host, strict_tls).await { - Ok(tls_stream) => { - if strict_tls { - dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) - .await?; - } - return Ok(tls_stream); - } - Err(err) => { - warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); - first_error.get_or_insert(err); - } - } - } - - Err(first_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) -} - -async fn connect_starttls_smtp_inner( - addr: SocketAddr, - host: &str, - strict_tls: bool, -) -> Result>>>> { - let tcp_stream = connect_tcp_inner(addr).await?; - - // Run STARTTLS command and convert the client back into a stream. - let client = async_smtp::SmtpClient::new().smtp_utf8(true); - let transport = async_smtp::SmtpTransport::new(client, BufStream::new(tcp_stream)).await?; - let tcp_stream = transport.starttls().await?.into_inner(); - let tls_stream = wrap_tls(strict_tls, host, "smtp", tcp_stream) - .await - .context("STARTTLS upgrade failed")?; - Ok(tls_stream) -} - -pub(crate) async fn connect_starttls_smtp( - context: &Context, - host: &str, - port: u16, - strict_tls: bool, -) -> Result>>>> { - let mut first_error = None; - - for resolved_addr in lookup_host_with_cache(context, host, port, strict_tls).await? { - match connect_starttls_smtp_inner(resolved_addr, host, strict_tls).await { - Ok(tls_stream) => { - if strict_tls { - dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) - .await?; - } - return Ok(tls_stream); - } - Err(err) => { - warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); - first_error.get_or_insert(err); - } - } - } - - Err(first_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) -} diff --git a/src/smtp/connect.rs b/src/smtp/connect.rs index 5bbc7c104..2edc0c2f9 100644 --- a/src/smtp/connect.rs +++ b/src/smtp/connect.rs @@ -1,13 +1,16 @@ //! SMTP connection establishment. -use anyhow::{bail, Context as _, Result}; +use std::net::SocketAddr; + +use anyhow::{bail, format_err, Context as _, Result}; use async_smtp::{SmtpClient, SmtpTransport}; use tokio::io::BufStream; use crate::context::Context; +use crate::net::dns::{lookup_host_with_cache, update_connect_timestamp}; use crate::net::session::SessionBufStream; use crate::net::tls::wrap_tls; -use crate::net::{connect_starttls_smtp, connect_tcp, connect_tls}; +use crate::net::{connect_tcp_inner, connect_tls_inner}; use crate::provider::Socket; use crate::socks::Socks5Config; @@ -21,36 +24,55 @@ use crate::socks::Socks5Config; /// to unify the result regardless of whether TLS or STARTTLS is used. pub(crate) async fn connect_stream( context: &Context, - domain: &str, + host: &str, port: u16, strict_tls: bool, socks5_config: Option, security: Socket, ) -> Result> { - let stream = if let Some(socks5_config) = socks5_config { - match security { + if let Some(socks5_config) = socks5_config { + let stream = match security { Socket::Automatic => bail!("SMTP port security is not configured"), Socket::Ssl => { - connect_secure_socks5(context, domain, port, strict_tls, socks5_config.clone()) + connect_secure_socks5(context, host, port, strict_tls, socks5_config.clone()) .await? } Socket::Starttls => { - connect_starttls_socks5(context, domain, port, strict_tls, socks5_config.clone()) + connect_starttls_socks5(context, host, port, strict_tls, socks5_config.clone()) .await? } Socket::Plain => { - connect_insecure_socks5(context, domain, port, socks5_config.clone()).await? + connect_insecure_socks5(context, host, port, socks5_config.clone()).await? + } + }; + Ok(stream) + } else { + let mut first_error = None; + let load_cache = strict_tls && (security == Socket::Ssl || security == Socket::Starttls); + + for resolved_addr in lookup_host_with_cache(context, host, port, load_cache).await? { + let res = match security { + Socket::Automatic => bail!("SMTP port security is not configured"), + Socket::Ssl => connect_secure(resolved_addr, host, strict_tls).await, + Socket::Starttls => connect_starttls(resolved_addr, host, strict_tls).await, + Socket::Plain => connect_insecure(resolved_addr).await, + }; + match res { + Ok(stream) => { + let ip_addr = resolved_addr.ip().to_string(); + if load_cache { + update_connect_timestamp(context, host, &ip_addr).await?; + } + return Ok(stream); + } + Err(err) => { + warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); + first_error.get_or_insert(err); + } } } - } else { - match security { - Socket::Automatic => bail!("SMTP port security is not configured"), - Socket::Ssl => connect_secure(context, domain, port, strict_tls).await?, - Socket::Starttls => connect_starttls(context, domain, port, strict_tls).await?, - Socket::Plain => connect_insecure(context, domain, port).await?, - } - }; - Ok(stream) + Err(first_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) + } } /// Reads and ignores SMTP greeting. @@ -132,12 +154,11 @@ async fn connect_insecure_socks5( } async fn connect_secure( - context: &Context, + addr: SocketAddr, hostname: &str, - port: u16, strict_tls: bool, ) -> Result> { - let tls_stream = connect_tls(context, hostname, port, strict_tls, "smtp").await?; + let tls_stream = connect_tls_inner(addr, hostname, strict_tls, "smtp").await?; let mut buffered_stream = BufStream::new(tls_stream); skip_smtp_greeting(&mut buffered_stream).await?; let session_stream: Box = Box::new(buffered_stream); @@ -145,24 +166,27 @@ async fn connect_secure( } async fn connect_starttls( - context: &Context, - hostname: &str, - port: u16, + addr: SocketAddr, + host: &str, strict_tls: bool, ) -> Result> { - let tls_stream = connect_starttls_smtp(context, hostname, port, strict_tls).await?; + let tcp_stream = connect_tcp_inner(addr).await?; + + // Run STARTTLS command and convert the client back into a stream. + let client = async_smtp::SmtpClient::new().smtp_utf8(true); + let transport = async_smtp::SmtpTransport::new(client, BufStream::new(tcp_stream)).await?; + let tcp_stream = transport.starttls().await?.into_inner(); + let tls_stream = wrap_tls(strict_tls, host, "smtp", tcp_stream) + .await + .context("STARTTLS upgrade failed")?; let buffered_stream = BufStream::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); Ok(session_stream) } -async fn connect_insecure( - context: &Context, - hostname: &str, - port: u16, -) -> Result> { - let tcp_stream = connect_tcp(context, hostname, port, false).await?; +async fn connect_insecure(addr: SocketAddr) -> Result> { + let tcp_stream = connect_tcp_inner(addr).await?; let mut buffered_stream = BufStream::new(tcp_stream); skip_smtp_greeting(&mut buffered_stream).await?; let session_stream: Box = Box::new(buffered_stream);