diff --git a/server/src/invites.rs b/server/src/invites.rs index aa1594a..52dc15b 100644 --- a/server/src/invites.rs +++ b/server/src/invites.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use mysql_async; -use mysql_async::Conn; +use mysql_async::{Conn, Pool}; use mysql_async::error::Error; use mysql_async::prelude::{params, Queryable}; @@ -14,12 +14,17 @@ struct InviteRow { expires: u64, uses: i32, } +/* + * Error handling: + * All errors raisable from this module come from mysql_async and thus + * are of the enum mysql_async::error::Error +*/ impl InviteRow { pub fn new() -> InviteRow { let dt = Utc::now() + chrono::Duration::minutes(30); // TODO:[maybe] ensure no collisions by doing a quick database check here - let mut invite = InviteRow { + let invite = InviteRow { id: random::(), // hopefully there won't ever be collision with this size of pool uses: 1, // default/hardcorded for now expires: dt.timestamp() as u64 @@ -47,46 +52,50 @@ impl InviteRow { } } -async fn get_invite_by_code(conn: &Conn, value: Option<&str>) -> Option { +async fn get_invite_by_code(pool: &Pool, value: Option<&str>) -> Result, Error> { if let Some(val) = value { - let db_row_result: Result<(Conn, Option<(u64, u64, i32)>), Error> = conn - .first_exec(r"SELECT * FROM", mysql_async::params!{"code"=>}) - .await; - match db_row_result { - Ok(data) => { - if let Some(tup) = data.1 {Some(InviteRow::from_tuple(tup))} - else {None} - } - Err(_) => None, + let conn = pool.get_conn().await?; + let db_row_result: (Conn, Option<(u64, u64, i32)>) = conn + .first_exec(r"SELECT * FROM", mysql_async::params!{"code"=>val}) + .await?; + if let Some(tup) = db_row_result.1 { + Ok(Some(InviteRow::from_tuple(tup))) + } + else { + // basically nothing was found but nothing bad happened + Ok(None) } } - else { - None - } + // again db didn't throw a fit but we don't have a good input + else {Ok(None)} } -async fn record_invite_usage(conn: &Conn, data: &InviteRow) { +async fn record_invite_usage(pool: &Pool, data: &InviteRow) -> Result<(), Error>{ /* * By this this is called we really don't care about what happens as we've * already been querying the db and the likely hood of this seriously failing * is low enough to write a wall of text and not a wall of error handling code */ + let conn = pool.get_conn().await?; let _db_result = conn .prep_exec(r"UPDATE invites SET uses = :uses WHERE id = :id", mysql_async::params!{ "uses" => data.uses - 1, "id" => data.id - }).await; + }).await?; + + Ok(()) } -pub async fn join_invite_code(conn: &Conn, response: &mut Response, params: &HashMap<&str, &str>) { +pub async fn join_invite_code(pool: &Pool, response: &mut Response, params: &HashMap<&str, &str>) -> Result<(), Error> { // First check that the code is there match params.get("code") { Some(p) => { - if let Some(row) = get_invite_by_code(conn, Some(*p)).await { + if let Some(row) = get_invite_by_code(pool, Some(*p)).await? { // since we have a row make sure the invite is valid let now = Utc::now().timestamp() as u64; + // usable and expires in the future if row.uses > 0 && row.expires > now { - record_invite_usage(conn, &row).await; + record_invite_usage(pool, &row).await?; // TODO: assign some actual data to the body *response.status_mut() = StatusCode::OK; } @@ -96,23 +105,20 @@ pub async fn join_invite_code(conn: &Conn, response: &mut Response, params *response.status_mut() = StatusCode::BAD_REQUEST; } } + Ok(()) } -pub async fn create_invite(conn: &Conn, response: &mut Response) { +pub async fn create_invite(pool: &Pool, response: &mut Response) -> Result<(), Error> { let invite = InviteRow::new(); - let db_result = conn - .prep_exec(r"INSERT INTO invites (id, expires, uses) VALUES (:id, :expires, :uses", + let conn = pool.get_conn().await?; + conn.prep_exec(r"INSERT INTO invites (id, expires, uses) VALUES (:id, :expires, :uses", mysql_async::params!{ "id" => invite.id, "expires" => invite.expires, - "uses" => invite.uses - }).await; + "uses" => invite.uses, + }).await?; - match db_result { - Ok(d) => { - *response.body_mut() = Body::from(invite.as_json_str()); - *response.status_mut() = StatusCode::OK; - } - Err(e) => {} - } + *response.body_mut() = Body::from(invite.as_json_str()); + *response.status_mut() = StatusCode::OK; + Ok(()) } \ No newline at end of file diff --git a/server/src/main.rs b/server/src/main.rs index b8bd142..85efae0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -16,7 +16,7 @@ use hyper::{ Method, StatusCode, service::{make_service_fn, service_fn} }; -use mysql_async::Conn; +use mysql_async::Pool; use dotenv::dotenv; mod auth; @@ -43,9 +43,13 @@ fn map_qs(query_string_raw: Option<&str>) -> HashMap<&str, &str> { map } -async fn route_dispatcher(conn: &Conn, resp: &mut Response, meth: &Method, path: &str, params: &HashMap<&str, &str>) { +async fn route_dispatcher(pool: &Pool, resp: &mut Response, meth: &Method, path: &str, params: &HashMap<&str, &str>) { match (meth, path) { - (&Method::GET, routes::INVITE_JOIN) => invites::join_invite_code(conn, &mut resp, params).await, + (&Method::GET, routes::INVITE_JOIN) => { + if let Err(_) = invites::join_invite_code(pool, resp, params).await { + *resp.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + } + }, _ => { *resp.status_mut() = StatusCode::NOT_FOUND; } @@ -59,19 +63,18 @@ async fn main_responder(request: Request) -> Result, hyper: let path = request.uri().path(); let params = map_qs(request.uri().query()); - if let Ok(conn) = Conn::from_url(env::var("DATABASE_URL").unwrap()).await { + let pool = Pool::new(&env::var("DATABASE_URL").unwrap()); // some more information in the response would be great right about here - match auth::wall_entry(path, conn, ¶ms).await { - OpenAuth | Good => route_dispatcher(&conn, &mut response, &method, path, ¶ms).await, + if let Ok(auth_result) = auth::wall_entry(path, &pool, ¶ms).await { + match auth_result { + OpenAuth | Good => route_dispatcher(&pool, &mut response, &method, path, ¶ms).await, LimitPassed => *response.status_mut() = StatusCode::UNAUTHORIZED, NoKey => *response.status_mut() = StatusCode::UNAUTHORIZED, - InternalFailure => *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR } } else { *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; } - Ok(response) }