feat: add wildcard pattern support to provider database

This commit is contained in:
link2xt
2024-02-05 21:17:03 +00:00
parent f1688d2b3f
commit f8907e3c83
4 changed files with 366 additions and 396 deletions

View File

@@ -215,8 +215,18 @@ pub async fn get_provider_info(
/// Finds a provider in offline database based on domain.
pub fn get_provider_by_domain(domain: &str) -> Option<&'static Provider> {
if let Some(provider) = PROVIDER_DATA.get(domain.to_lowercase().as_str()) {
return Some(*provider);
let domain = domain.to_lowercase();
for (pattern, provider) in PROVIDER_DATA {
if let Some(suffix) = pattern.strip_prefix('*') {
// Wildcard domain pattern.
//
// For example, `suffix` is ".hermes.radio" for "*.hermes.radio" pattern.
if domain.ends_with(suffix) {
return Some(provider);
}
} else if pattern == domain {
return Some(provider);
}
}
None
@@ -226,33 +236,42 @@ pub fn get_provider_by_domain(domain: &str) -> Option<&'static Provider> {
///
/// For security reasons, only Gmail can be configured this way.
pub async fn get_provider_by_mx(context: &Context, domain: &str) -> Option<&'static Provider> {
if let Ok(resolver) = get_resolver() {
let mut fqdn: String = domain.to_string();
if !fqdn.ends_with('.') {
fqdn.push('.');
let Ok(resolver) = get_resolver() else {
warn!(context, "Cannot get a resolver to check MX records.");
return None;
};
let mut fqdn: String = domain.to_string();
if !fqdn.ends_with('.') {
fqdn.push('.');
}
let Ok(mx_domains) = resolver.mx_lookup(fqdn).await else {
warn!(context, "Cannot resolve MX records for {domain:?}.");
return None;
};
for (provider_domain_pattern, provider) in PROVIDER_DATA {
if provider.id != "gmail" {
// MX lookup is limited to Gmail for security reasons
continue;
}
if let Ok(mx_domains) = resolver.mx_lookup(fqdn).await {
for (provider_domain, provider) in &*PROVIDER_DATA {
if provider.id != "gmail" {
// MX lookup is limited to Gmail for security reasons
continue;
}
if provider_domain_pattern.starts_with('*') {
// Skip wildcard patterns.
continue;
}
let provider_fqdn = provider_domain.to_string() + ".";
let provider_fqdn_dot = ".".to_string() + &provider_fqdn;
let provider_fqdn = provider_domain_pattern.to_string() + ".";
let provider_fqdn_dot = ".".to_string() + &provider_fqdn;
for mx_domain in mx_domains.iter() {
let mx_domain = mx_domain.exchange().to_lowercase().to_utf8();
for mx_domain in mx_domains.iter() {
let mx_domain = mx_domain.exchange().to_lowercase().to_utf8();
if mx_domain == provider_fqdn || mx_domain.ends_with(&provider_fqdn_dot) {
return Some(provider);
}
}
if mx_domain == provider_fqdn || mx_domain.ends_with(&provider_fqdn_dot) {
return Some(provider);
}
}
} else {
warn!(context, "cannot get a resolver to check MX records.");
}
None