diff --git a/src/main.rs b/src/main.rs index ab7e57f..f899a8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,17 @@ use db::N_STATEMENTS; -use rocket::{futures::lock::Mutex, get, fs::FileServer, http::Status, response::Redirect, response::content::RawHtml, routes, State}; +use rocket::{ + fs::FileServer, futures::lock::Mutex, get, http::Status, response::content::RawHtml, + response::Redirect, routes, State, +}; +use std::{ + env, + fs::File, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + str::FromStr, + sync::Arc, + time::Duration, +}; use tokio_postgres::{Client, NoTls, Statement}; -use std::{fs::File, sync::Arc, time::Duration}; mod db; mod linkgen; @@ -16,7 +26,8 @@ pub struct GlobalState { } #[tokio::main] async fn main() { - let postgres_config = "host=localhost user=shortener password=".to_owned() + include_str!("../postgres_password.txt").trim(); + let postgres_config = "host=localhost user=shortener password=".to_owned() + + include_str!("../postgres_password.txt").trim(); let (client, conn) = match tokio_postgres::connect(&postgres_config, NoTls).await { Ok((client, conn)) => (client, conn), Err(e) => { @@ -31,9 +42,21 @@ async fn main() { }); db::prepare_tables(&client).await.unwrap(); let statements = db::prepare_statements(&client).await.unwrap(); + + let default_listen_addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 3020).into(); + let listen_addr = env::var("SERVER_ADDR") + .ok() + .map(|v| SocketAddr::from_str(&v).ok()) + .flatten() + .unwrap_or(default_listen_addr); let mut config = rocket::Config::default(); - config.port = 3020; - let state = GlobalState { db_client: client, statements }; + config.address = listen_addr.ip(); + config.port = listen_addr.port(); + + let state = GlobalState { + db_client: client, + statements, + }; let state_mutex = Arc::new(Mutex::new(state)); { let state_mutex = state_mutex.clone(); @@ -68,7 +91,11 @@ async fn create(state: &State>, url: &str, secret: Option<&st custom_link = true; } } - let actual_len = if allow_secret_options && length.is_some() { length.unwrap() } else { DEFAULT_LINK_LENGTH }; + let actual_len = if allow_secret_options && length.is_some() { + length.unwrap() + } else { + DEFAULT_LINK_LENGTH + }; if actual_len < 1 || actual_len > 64 { return (Status::BadRequest, "invalid length".to_owned()); } @@ -86,7 +113,7 @@ async fn create(state: &State>, url: &str, secret: Option<&st if custom_link { return (Status::BadRequest, "link already exists".to_owned()); } - }, + } None => { actual_link = Some(link_id.clone()); break; @@ -101,7 +128,7 @@ async fn create(state: &State>, url: &str, secret: Option<&st let res = match ttl { Some(ttl) => db::add_temporary_link(&state, actual_link.as_ref().unwrap(), url, ttl).await, - None => db::add_link(&state, actual_link.as_ref().unwrap(), url).await + None => db::add_link(&state, actual_link.as_ref().unwrap(), url).await, }; match res { Ok(_) => (), @@ -116,7 +143,10 @@ async fn create(state: &State>, url: &str, secret: Option<&st } #[get("/")] -pub async fn go_to_link(state: &State>, link: &str) -> Result, Status> { +pub async fn go_to_link( + state: &State>>, + link: &str, +) -> Result, Status> { match db::get_link(&state, link).await { Ok(url) => match url { Some(url) => Ok(Some(Redirect::to(url))),