make SERVER_ADDR env variable override the listening address

This commit is contained in:
2026-03-09 12:40:11 +03:00
parent f510b73d05
commit a430e7f895

View File

@@ -1,7 +1,17 @@
use db::N_STATEMENTS;
use rocket::{futures::lock::Mutex, get, fs::FileServer, http::Status, response::Redirect, response::content::RawHtml, routes, State};
use rocket::{
fs::FileServer, futures::lock::Mutex, get, http::Status, response::content::RawHtml,
response::Redirect, routes, State,
};
use std::{
env,
fs::File,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
str::FromStr,
sync::Arc,
time::Duration,
};
use tokio_postgres::{Client, NoTls, Statement};
use std::{fs::File, sync::Arc, time::Duration};
mod db;
mod linkgen;
@@ -16,7 +26,8 @@ pub struct GlobalState {
}
#[tokio::main]
async fn main() {
let postgres_config = "host=localhost user=shortener password=".to_owned() + include_str!("../postgres_password.txt").trim();
let postgres_config = "host=localhost user=shortener password=".to_owned()
+ include_str!("../postgres_password.txt").trim();
let (client, conn) = match tokio_postgres::connect(&postgres_config, NoTls).await {
Ok((client, conn)) => (client, conn),
Err(e) => {
@@ -31,9 +42,21 @@ async fn main() {
});
db::prepare_tables(&client).await.unwrap();
let statements = db::prepare_statements(&client).await.unwrap();
let default_listen_addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3020).into();
let listen_addr = env::var("SERVER_ADDR")
.ok()
.map(|v| SocketAddr::from_str(&v).ok())
.flatten()
.unwrap_or(default_listen_addr);
let mut config = rocket::Config::default();
config.port = 3020;
let state = GlobalState { db_client: client, statements };
config.address = listen_addr.ip();
config.port = listen_addr.port();
let state = GlobalState {
db_client: client,
statements,
};
let state_mutex = Arc::new(Mutex::new(state));
{
let state_mutex = state_mutex.clone();
@@ -68,7 +91,11 @@ async fn create(state: &State<Mutex<GlobalState>>, url: &str, secret: Option<&st
custom_link = true;
}
}
let actual_len = if allow_secret_options && length.is_some() { length.unwrap() } else { DEFAULT_LINK_LENGTH };
let actual_len = if allow_secret_options && length.is_some() {
length.unwrap()
} else {
DEFAULT_LINK_LENGTH
};
if actual_len < 1 || actual_len > 64 {
return (Status::BadRequest, "invalid length".to_owned());
}
@@ -86,7 +113,7 @@ async fn create(state: &State<Mutex<GlobalState>>, url: &str, secret: Option<&st
if custom_link {
return (Status::BadRequest, "link already exists".to_owned());
}
},
}
None => {
actual_link = Some(link_id.clone());
break;
@@ -101,7 +128,7 @@ async fn create(state: &State<Mutex<GlobalState>>, url: &str, secret: Option<&st
let res = match ttl {
Some(ttl) => db::add_temporary_link(&state, actual_link.as_ref().unwrap(), url, ttl).await,
None => db::add_link(&state, actual_link.as_ref().unwrap(), url).await
None => db::add_link(&state, actual_link.as_ref().unwrap(), url).await,
};
match res {
Ok(_) => (),
@@ -116,7 +143,10 @@ async fn create(state: &State<Mutex<GlobalState>>, url: &str, secret: Option<&st
}
#[get("/<link>")]
pub async fn go_to_link(state: &State<Mutex<GlobalState>>, link: &str) -> Result<Option<Redirect>, Status> {
pub async fn go_to_link(
state: &State<Arc<Mutex<GlobalState>>>,
link: &str,
) -> Result<Option<Redirect>, Status> {
match db::get_link(&state, link).await {
Ok(url) => match url {
Some(url) => Ok(Some(Redirect::to(url))),