diff --git a/src/db.rs b/src/db.rs index df385bc..cb818d3 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,3 +1,5 @@ +use std::time::{Duration, SystemTime}; + use rocket::futures::lock::Mutex; use tokio_postgres::{Client, Error, Statement}; @@ -5,7 +7,9 @@ use crate::GlobalState; pub const STATEMENT_GET_LINK: usize = 0; pub const STATEMENT_ADD_LINK: usize = 1; -pub const N_STATEMENTS: usize = 2; +pub const STATEMENT_ADD_TEMP_LINK: usize = 2; +pub const STATEMENT_DELETE_EXPIRED: usize = 3; +pub const N_STATEMENTS: usize = 4; pub async fn get_link(state: &Mutex, link: &str) -> Result, Error> { let lock = state.lock().await; @@ -18,13 +22,28 @@ pub async fn add_link(state: &Mutex, link: &str, url: &str) -> Resu Ok(()) } +pub async fn add_temporary_link(state: &Mutex, link: &str, url: &str, ttl: u64) -> Result<(), Error> { + let time = SystemTime::now() + Duration::from_secs(ttl); + let lock = state.lock().await; + lock.db_client.execute(&lock.statements[STATEMENT_ADD_TEMP_LINK], &[&link, &url, &time]).await?; + Ok(()) +} + +pub async fn delete_expired_links(state: &Mutex) -> Result<(), Error> { + let lock = state.lock().await; + lock.db_client.execute(&lock.statements[STATEMENT_DELETE_EXPIRED], &[]).await?; + Ok(()) +} + pub async fn prepare_statements(db: &Client) -> Result<[Statement; N_STATEMENTS], Error> { Ok([db.prepare("SELECT url FROM links WHERE id = $1").await?, db.prepare("INSERT INTO links (id, url) VALUES ($1, $2)").await?, + db.prepare("INSERT INTO links (id, url, valid_until) VALUES ($1, $2, $3)").await?, + db.prepare("DELETE FROM links WHERE valid_until < NOW()").await? ]) } pub async fn prepare_tables(db: &Client) -> Result<(), Error> { - db.execute("CREATE TABLE IF NOT EXISTS links (id TEXT PRIMARY KEY, url TEXT NOT NULL)", &[]).await?; + db.execute("CREATE TABLE IF NOT EXISTS links (id TEXT PRIMARY KEY, url TEXT NOT NULL, valid_until TIMESTAMP NOT NULL)", &[]).await?; Ok(()) } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index de05a75..ab7e57f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use db::N_STATEMENTS; use rocket::{futures::lock::Mutex, get, fs::FileServer, http::Status, response::Redirect, response::content::RawHtml, routes, State}; use tokio_postgres::{Client, NoTls, Statement}; -use std::fs::File; +use std::{fs::File, sync::Arc, time::Duration}; mod db; mod linkgen; @@ -17,7 +17,13 @@ pub struct GlobalState { #[tokio::main] async fn main() { let postgres_config = "host=localhost user=shortener password=".to_owned() + include_str!("../postgres_password.txt").trim(); - let (client, conn) = tokio_postgres::connect(&postgres_config, NoTls).await.unwrap(); + let (client, conn) = match tokio_postgres::connect(&postgres_config, NoTls).await { + Ok((client, conn)) => (client, conn), + Err(e) => { + eprintln!("Failed to connect to PostgreSQL: {}", e); + return; + } + }; tokio::spawn(async move { if let Err(e) = conn.await { eprintln!("postgresql error: {}", e); @@ -28,16 +34,26 @@ async fn main() { let mut config = rocket::Config::default(); config.port = 3020; let state = GlobalState { db_client: client, statements }; + let state_mutex = Arc::new(Mutex::new(state)); + { + let state_mutex = state_mutex.clone(); + tokio::spawn(async move { + loop { + delete_expired(&state_mutex).await.unwrap(); + tokio::time::sleep(Duration::from_secs(2 * 3600)).await; + } + }); + } rocket::build() .mount("/", routes![create, go_to_link]) - .manage(Mutex::new(state)) + .manage(state_mutex.clone()) .configure(config) .launch() .await.unwrap(); } -#[get("/create?&&&")] -async fn create(state: &State>, url: &str, secret: Option<&str>, length: Option, link: Option<&str>) -> (Status, String) { +#[get("/create?&&&&")] +async fn create(state: &State>, url: &str, secret: Option<&str>, length: Option, link: Option<&str>, ttl: Option) -> (Status, String) { let mut allow_secret_options = false; if let Some(secret) = secret { if secret == include_str!("../secret.txt").trim() { @@ -83,7 +99,11 @@ async fn create(state: &State>, url: &str, secret: Option<&st } } - match db::add_link(&state, actual_link.as_ref().unwrap(), url).await { + let res = match ttl { + Some(ttl) => db::add_temporary_link(&state, actual_link.as_ref().unwrap(), url, ttl).await, + None => db::add_link(&state, actual_link.as_ref().unwrap(), url).await + }; + match res { Ok(_) => (), Err(e) => { eprintln!("add link error: {}", e); @@ -108,3 +128,7 @@ pub async fn go_to_link(state: &State>, link: &str) -> Result } } } + +async fn delete_expired(state: &Mutex) -> Result<(), tokio_postgres::Error> { + db::delete_expired_links(state).await +}