refactor: merge build_tls() function into wrap_tls()

This commit is contained in:
link2xt
2024-09-21 19:25:06 +00:00
parent fe0c9958a6
commit 638da904e7

View File

@@ -14,41 +14,42 @@ static LETSENCRYPT_ROOT: Lazy<Certificate> = Lazy::new(|| {
.unwrap() .unwrap()
}); });
pub fn build_tls(strict_tls: bool, alpns: &[&str]) -> TlsConnector {
let tls_builder = TlsConnector::new()
.min_protocol_version(Some(Protocol::Tlsv12))
.request_alpns(alpns)
.add_root_certificate(LETSENCRYPT_ROOT.clone());
if strict_tls {
tls_builder
} else {
tls_builder
.danger_accept_invalid_hostnames(true)
.danger_accept_invalid_certs(true)
}
}
pub async fn wrap_tls<T: AsyncRead + AsyncWrite + Unpin>( pub async fn wrap_tls<T: AsyncRead + AsyncWrite + Unpin>(
strict_tls: bool, strict_tls: bool,
hostname: &str, hostname: &str,
alpn: &[&str], alpn: &[&str],
stream: T, stream: T,
) -> Result<TlsStream<T>> { ) -> Result<TlsStream<T>> {
let tls = build_tls(strict_tls, alpn); let tls_builder = TlsConnector::new()
.min_protocol_version(Some(Protocol::Tlsv12))
.request_alpns(alpn)
.add_root_certificate(LETSENCRYPT_ROOT.clone());
let tls = if strict_tls {
tls_builder
} else {
tls_builder
.danger_accept_invalid_hostnames(true)
.danger_accept_invalid_certs(true)
};
let tls_stream = tls.connect(hostname, stream).await?; let tls_stream = tls.connect(hostname, stream).await?;
Ok(tls_stream) Ok(tls_stream)
} }
#[cfg(test)] pub async fn wrap_rustls<T: AsyncRead + AsyncWrite + Unpin>(
mod tests { hostname: &str,
use super::*; alpn: &[&str],
stream: T,
) -> Result<tokio_rustls::client::TlsStream<T>> {
let mut root_cert_store = rustls::RootCertStore::empty();
root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
#[test] let mut config = rustls::ClientConfig::builder()
fn test_build_tls() { .with_root_certificates(root_cert_store)
// we are using some additional root certificates. .with_no_client_auth();
// make sure, they do not break construction of TlsConnector config.alpn_protocols = alpn.into_iter().map(|s| s.as_bytes().to_vec()).collect();
let _ = build_tls(true, &[]);
let _ = build_tls(false, &[]); let tls = tokio_rustls::TlsConnector::from(Arc::new(config));
} let name = rustls_pki_types::ServerName::try_from(hostname)?.to_owned();
let tls_stream = tls.connect(name, stream).await?;
Ok(tls_stream)
} }