diff --git a/src/imap/client.rs b/src/imap/client.rs index db9c8e353..acb58b6c2 100644 --- a/src/imap/client.rs +++ b/src/imap/client.rs @@ -11,9 +11,9 @@ use tokio::io::BufWriter; use super::capabilities::Capabilities; use super::session::Session; use crate::context::Context; -use crate::net::connect_tcp; use crate::net::session::SessionStream; use crate::net::tls::wrap_tls; +use crate::net::{connect_starttls_imap, connect_tcp, connect_tls}; use crate::socks::Socks5Config; use fast_socks5::client::Socks5Stream; @@ -104,8 +104,8 @@ impl Client { port: u16, strict_tls: bool, ) -> Result { - let tcp_stream = connect_tcp(context, hostname, port, IMAP_TIMEOUT, strict_tls).await?; - let tls_stream = wrap_tls(strict_tls, hostname, &["imap"], tcp_stream).await?; + let tls_stream = + connect_tls(context, hostname, port, IMAP_TIMEOUT, strict_tls, &["imap"]).await?; let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let mut client = Client::new(session_stream); @@ -134,25 +134,8 @@ impl Client { port: u16, strict_tls: bool, ) -> Result { - let tcp_stream = connect_tcp(context, hostname, port, IMAP_TIMEOUT, strict_tls).await?; - - // Run STARTTLS command and convert the client back into a stream. - let buffered_tcp_stream = BufWriter::new(tcp_stream); - let mut client = ImapClient::new(buffered_tcp_stream); - let _greeting = client - .read_response() - .await - .context("failed to read greeting")??; - client - .run_command_and_check_ok("STARTTLS", None) - .await - .context("STARTTLS command failed")?; - let buffered_tcp_stream = client.into_inner(); - let tcp_stream = buffered_tcp_stream.into_inner(); - - let tls_stream = wrap_tls(strict_tls, hostname, &["imap"], tcp_stream) - .await - .context("STARTTLS upgrade failed")?; + let tls_stream = + connect_starttls_imap(context, hostname, port, IMAP_TIMEOUT, strict_tls).await?; let buffered_stream = BufWriter::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); diff --git a/src/net.rs b/src/net.rs index d76f0ff7e..493a46641 100644 --- a/src/net.rs +++ b/src/net.rs @@ -4,12 +4,14 @@ use std::pin::Pin; use std::time::Duration; use anyhow::{format_err, Context as _, Result}; +use async_native_tls::TlsStream; +use tokio::io::BufStream; +use tokio::io::BufWriter; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_io_timeout::TimeoutStream; use crate::context::Context; -use crate::tools::time; pub(crate) mod dns; pub(crate) mod http; @@ -18,21 +20,46 @@ pub(crate) mod tls; use dns::lookup_host_with_cache; pub use http::{read_url, read_url_blob, Response as HttpResponse}; - -async fn connect_tcp_inner(addr: SocketAddr, timeout_val: Duration) -> Result { - let tcp_stream = timeout(timeout_val, TcpStream::connect(addr)) - .await - .context("connection timeout")? - .context("connection failure")?; - Ok(tcp_stream) -} +use tls::wrap_tls; /// Returns a TCP connection stream with read/write timeouts set /// and Nagle's algorithm disabled with `TCP_NODELAY`. /// /// `TCP_NODELAY` ensures writing to the stream always results in immediate sending of the packet /// to the network, which is important to reduce the latency of interactive protocols such as IMAP. -/// +async fn connect_tcp_inner( + addr: SocketAddr, + timeout_val: Duration, +) -> Result>>> { + let tcp_stream = timeout(timeout_val, TcpStream::connect(addr)) + .await + .context("connection timeout")? + .context("connection failure")?; + + // Disable Nagle's algorithm. + tcp_stream.set_nodelay(true)?; + + let mut timeout_stream = TimeoutStream::new(tcp_stream); + timeout_stream.set_write_timeout(Some(timeout_val)); + timeout_stream.set_read_timeout(Some(timeout_val)); + + Ok(Box::pin(timeout_stream)) +} + +/// Attempts to establish TLS connection +/// given the result of the hostname to address resolution. +async fn connect_tls_inner( + addr: SocketAddr, + timeout_val: Duration, + host: &str, + strict_tls: bool, + alpns: &[&str], +) -> Result>>>> { + let tcp_stream = connect_tcp_inner(addr, timeout_val).await?; + let tls_stream = wrap_tls(strict_tls, host, alpns, tcp_stream).await?; + Ok(tls_stream) +} + /// If `load_cache` is true, may use cached DNS results. /// Because the cache may be poisoned with incorrect results by networks hijacking DNS requests, /// this option should only be used when connection is authenticated, @@ -46,7 +73,6 @@ pub(crate) async fn connect_tcp( timeout_val: Duration, load_cache: bool, ) -> Result>>> { - let mut tcp_stream = None; let mut last_error = None; for resolved_addr in @@ -54,30 +80,7 @@ pub(crate) async fn connect_tcp( { match connect_tcp_inner(resolved_addr, timeout_val).await { Ok(stream) => { - tcp_stream = Some(stream); - - // Update timestamp of this cached entry - // or insert a new one if cached entry does not exist. - // - // This increases priority of existing cached entries - // and copies fallback addresses from build-in cache - // into database cache on successful use. - // - // Unlike built-in cache, - // database cache is used even if DNS - // resolver returns a non-empty - // (but potentially incorrect and unusable) result. - context - .sql - .execute( - "INSERT INTO dns_cache (hostname, address, timestamp) - VALUES (?, ?, ?) - ON CONFLICT (hostname, address) - DO UPDATE SET timestamp=excluded.timestamp", - (host, resolved_addr.ip().to_string(), time()), - ) - .await?; - break; + return Ok(stream); } Err(err) => { warn!( @@ -89,22 +92,143 @@ pub(crate) async fn connect_tcp( } } - let tcp_stream = match tcp_stream { - Some(tcp_stream) => tcp_stream, - None => { - return Err( - last_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}")) - ); - } - }; - - // Disable Nagle's algorithm. - tcp_stream.set_nodelay(true)?; - - let mut timeout_stream = TimeoutStream::new(tcp_stream); - timeout_stream.set_write_timeout(Some(timeout_val)); - timeout_stream.set_read_timeout(Some(timeout_val)); - let pinned_stream = Box::pin(timeout_stream); - - Ok(pinned_stream) + Err(last_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) +} + +pub(crate) async fn connect_tls( + context: &Context, + host: &str, + port: u16, + timeout_val: Duration, + strict_tls: bool, + alpns: &[&str], +) -> Result>>>> { + let mut last_error = None; + + for resolved_addr in + lookup_host_with_cache(context, host, port, timeout_val, strict_tls).await? + { + match connect_tls_inner(resolved_addr, timeout_val, host, strict_tls, alpns).await { + Ok(tls_stream) => { + if strict_tls { + dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) + .await?; + } + return Ok(tls_stream); + } + Err(err) => { + warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); + last_error = Some(err); + } + } + } + + Err(last_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) +} + +async fn connect_starttls_imap_inner( + addr: SocketAddr, + host: &str, + timeout_val: Duration, + strict_tls: bool, +) -> Result>>>> { + let tcp_stream = connect_tcp_inner(addr, timeout_val).await?; + + // Run STARTTLS command and convert the client back into a stream. + let buffered_tcp_stream = BufWriter::new(tcp_stream); + let mut client = async_imap::Client::new(buffered_tcp_stream); + let _greeting = client + .read_response() + .await + .context("failed to read greeting")??; + client + .run_command_and_check_ok("STARTTLS", None) + .await + .context("STARTTLS command failed")?; + let buffered_tcp_stream = client.into_inner(); + let tcp_stream = buffered_tcp_stream.into_inner(); + + let tls_stream = wrap_tls(strict_tls, host, &["imap"], tcp_stream) + .await + .context("STARTTLS upgrade failed")?; + + Ok(tls_stream) +} + +pub(crate) async fn connect_starttls_imap( + context: &Context, + host: &str, + port: u16, + timeout_val: Duration, + strict_tls: bool, +) -> Result>>>> { + let mut last_error = None; + + for resolved_addr in + lookup_host_with_cache(context, host, port, timeout_val, strict_tls).await? + { + match connect_starttls_imap_inner(resolved_addr, host, timeout_val, strict_tls).await { + Ok(tls_stream) => { + if strict_tls { + dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) + .await?; + } + return Ok(tls_stream); + } + Err(err) => { + warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); + last_error = Some(err); + } + } + } + + Err(last_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) +} + +async fn connect_starttls_smtp_inner( + addr: SocketAddr, + host: &str, + timeout_val: Duration, + strict_tls: bool, +) -> Result>>>> { + let tcp_stream = connect_tcp_inner(addr, timeout_val).await?; + + // Run STARTTLS command and convert the client back into a stream. + let client = async_smtp::SmtpClient::new().smtp_utf8(true); + let transport = async_smtp::SmtpTransport::new(client, BufStream::new(tcp_stream)).await?; + let tcp_stream = transport.starttls().await?.into_inner(); + let tls_stream = wrap_tls(strict_tls, host, &["smtp"], tcp_stream) + .await + .context("STARTTLS upgrade failed")?; + Ok(tls_stream) +} + +pub(crate) async fn connect_starttls_smtp( + context: &Context, + host: &str, + port: u16, + timeout_val: Duration, + strict_tls: bool, +) -> Result>>>> { + let mut last_error = None; + + for resolved_addr in + lookup_host_with_cache(context, host, port, timeout_val, strict_tls).await? + { + match connect_starttls_smtp_inner(resolved_addr, host, timeout_val, strict_tls).await { + Ok(tls_stream) => { + if strict_tls { + dns::update_connect_timestamp(context, host, &resolved_addr.ip().to_string()) + .await?; + } + return Ok(tls_stream); + } + Err(err) => { + warn!(context, "Failed to connect to {resolved_addr}: {err:#}."); + last_error = Some(err); + } + } + } + + Err(last_error.unwrap_or_else(|| format_err!("no DNS resolution results for {host}"))) } diff --git a/src/net/dns.rs b/src/net/dns.rs index 4fe75489e..96dbaca36 100644 --- a/src/net/dns.rs +++ b/src/net/dns.rs @@ -22,6 +22,42 @@ async fn lookup_host_with_timeout( Ok(res.collect()) } +// Updates timestamp of the cached entry +// or inserts a new one if cached entry does not exist. +// +// This function should be called when a successful TLS +// connection is established with strict TLS checks. +// +// This increases priority of existing cached entries +// and copies fallback addresses from built-in cache +// into database cache on successful use. +// +// Unlike built-in cache, +// database cache is used even if DNS +// resolver returns a non-empty +// (but potentially incorrect and unusable) result. +pub(crate) async fn update_connect_timestamp( + context: &Context, + host: &str, + address: &str, +) -> Result<()> { + if host == address { + return Ok(()); + } + + context + .sql + .execute( + "INSERT INTO dns_cache (hostname, address, timestamp) + VALUES (?, ?, ?) + ON CONFLICT (hostname, address) + DO UPDATE SET timestamp=excluded.timestamp", + (host, address, time()), + ) + .await?; + Ok(()) +} + /// Looks up hostname and port using DNS and updates the address resolution cache. /// /// If `load_cache` is true, appends cached results not older than 30 days to the end @@ -39,7 +75,7 @@ pub(crate) async fn lookup_host_with_cache( Err(err) => { warn!( context, - "DNS resolution for {}:{} failed: {:#}.", hostname, port, err + "DNS resolution for {hostname}:{port} failed: {err:#}." ); Vec::new() } diff --git a/src/smtp.rs b/src/smtp.rs index 0b626aba8..bcb4f8264 100644 --- a/src/smtp.rs +++ b/src/smtp.rs @@ -19,9 +19,9 @@ use crate::login_param::{CertificateChecks, LoginParam, ServerLoginParam}; use crate::message::Message; use crate::message::{self, MsgId}; use crate::mimefactory::MimeFactory; -use crate::net::connect_tcp; use crate::net::session::SessionBufStream; use crate::net::tls::wrap_tls; +use crate::net::{connect_starttls_smtp, connect_tcp, connect_tls}; use crate::oauth2::get_oauth2_access_token; use crate::provider::Socket; use crate::scheduler::connectivity::ConnectivityStore; @@ -178,8 +178,8 @@ impl Smtp { port: u16, strict_tls: bool, ) -> Result>> { - let tcp_stream = connect_tcp(context, hostname, port, SMTP_TIMEOUT, strict_tls).await?; - let tls_stream = wrap_tls(strict_tls, hostname, &["smtp"], tcp_stream).await?; + let tls_stream = + connect_tls(context, hostname, port, SMTP_TIMEOUT, strict_tls, &["smtp"]).await?; let buffered_stream = BufStream::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let client = smtp::SmtpClient::new().smtp_utf8(true); @@ -194,15 +194,9 @@ impl Smtp { port: u16, strict_tls: bool, ) -> Result>> { - let tcp_stream = connect_tcp(context, hostname, port, SMTP_TIMEOUT, strict_tls).await?; + let tls_stream = + connect_starttls_smtp(context, hostname, port, SMTP_TIMEOUT, strict_tls).await?; - // Run STARTTLS command and convert the client back into a stream. - let client = smtp::SmtpClient::new().smtp_utf8(true); - let transport = SmtpTransport::new(client, BufStream::new(tcp_stream)).await?; - let tcp_stream = transport.starttls().await?.into_inner(); - let tls_stream = wrap_tls(strict_tls, hostname, &["smtp"], tcp_stream) - .await - .context("STARTTLS upgrade failed")?; let buffered_stream = BufStream::new(tls_stream); let session_stream: Box = Box::new(buffered_stream); let client = smtp::SmtpClient::new().smtp_utf8(true).without_greeting();