// Handlers for the base auth routes
|
|
use crate::{
|
|
DBConn, schema,
|
|
models::{
|
|
Invite,
|
|
User
|
|
}
|
|
};
|
|
|
|
use rocket::http::Status;
|
|
use rocket::response::{self, Responder, Response};
|
|
use rocket::request::{Form, Request};
|
|
use rocket_contrib::json::Json;
|
|
use diesel::{self, prelude::*};
|
|
use std::{error, fmt};
|
|
|
|
#[allow(dead_code)] // added because these fields are read through rocket, not directly; and rls keeps complainin
|
|
#[derive(FromForm)]
|
|
pub struct JoinParams {
|
|
code: u64,
|
|
name: String,
|
|
}
|
|
|
|
#[derive(FromForm, Deserialize)]
|
|
pub struct AuthKey {
|
|
id: u64,
|
|
secret: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
pub struct SessionToken {
|
|
pub data: String
|
|
}
|
|
/// Result alias used by the auth handlers.
///
/// The second parameter is an ordinary generic error type. It is named `E`
/// so it is not mistaken for the concrete `AuthErr` struct; callers keep
/// passing both type arguments exactly as before.
pub type AuthResult<T, E> = std::result::Result<T, E>;
|
|
|
|
/// Error value returned by the auth handlers.
#[derive(Clone, Debug)]
pub struct AuthErr {
    /// Short, static description of what went wrong.
    msg: &'static str,
    /// HTTP status code associated with the failure.
    status: u16,
}
|
|
|
|
impl fmt::Display for AuthErr {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
write!(f, "Authentication error")
|
|
}
|
|
}
|
|
|
|
impl error::Error for AuthErr {
|
|
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
|
|
None
|
|
}
|
|
}
|
|
|
|
impl<'r> Responder<'r> for AuthErr {
|
|
fn respond_to(self, _:&Request) -> response::Result<'r> {
|
|
Response::build()
|
|
.status(Status::InternalServerError)
|
|
.raw_header("db-error", self.msg)
|
|
.ok()
|
|
}
|
|
}
|
|
|
|
|
|
pub fn join(conn: DBConn, hashcode: u64, name: String) -> AuthResult<Json<User>, AuthErr>{
|
|
/*
|
|
* Requires <code:int> -> body
|
|
* Requires <name:string> -> body
|
|
* Struct JoinParams enforces this for us so if something is missing then rocket should 404
|
|
*/
|
|
use schema::invites::{self, dsl::*};
|
|
|
|
let diesel_result = invites
|
|
.filter(invites::dsl::id.eq(hashcode))
|
|
.first::<Invite>(&conn.0);
|
|
|
|
if let Ok(data) = diesel_result {
|
|
match data.uses {
|
|
1 ..= std::i32::MAX => {
|
|
let new_user = crate::users::create_new_user(name);
|
|
// At this point we don't really care about the return
|
|
let _ignore = diesel::update(invites.filter(invites::dsl::id.eq(hashcode)))
|
|
.set(uses.eq(data.uses - 1))
|
|
.execute(&conn.0);
|
|
|
|
Ok(Json(new_user))
|
|
}
|
|
// The invite has been used up and thus should be removed
|
|
std::i32::MIN ..= 0 => {
|
|
let _ = diesel::delete(invites.filter(invites::dsl::id.eq(data.id)))
|
|
.execute(&conn.0)
|
|
.expect("Could not delete invite");
|
|
|
|
Err(AuthErr{msg: "Invite expired", status: 404})
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
Err(AuthErr{msg: "Malformed request", status: 500})
|
|
}
|
|
}
|
|
|
|
fn confirm_user_api_access(conn: &MysqlConnection, user_id: u64, user_secret: &str) -> bool {
|
|
use schema::users::{self, dsl::*};
|
|
let result = users
|
|
.filter(id.eq(user_id))
|
|
.filter(secret.eq(user_secret))
|
|
.first::<User>(conn);
|
|
|
|
match result {
|
|
Ok(_data) => true,
|
|
Err(_e) => false
|
|
}
|
|
}
|
|
|
|
fn blind_remove_session(conn: &MysqlConnection, sesh_secret: &str) {
|
|
}
|
|
|
|
#[post("/login", data = "<api_key>")]
|
|
pub fn login(conn: DBConn, api_key: Form<AuthKey>) -> AuthResult<Json<SessionToken>, AuthErr>{
|
|
/*
|
|
* Session Tokens are used to key into a subset of online users
|
|
* This is what should make queries faster per instance as we'll have less data to sift through w/ diesel
|
|
*/
|
|
|
|
if confirm_user_api_access(&conn.0, api_key.id, &api_key.secret) {
|
|
// Dump any tokens from before and make a new one
|
|
blind_remove_session(&conn.0, &api_key.secret);
|
|
Ok(Json(SessionToken {
|
|
data: "skeleton code".to_string()
|
|
}))
|
|
}
|
|
else {
|
|
Err(AuthErr {
|
|
msg: "Nothing found",
|
|
status: 400
|
|
})
|
|
}
|
|
}
|
|
|
|
#[post("/leave", data = "<api_key>")]
|
|
pub fn leave(conn: DBConn, api_key: Form<AuthKey>) -> Status {
|
|
/*
|
|
* Basic removal of the user from our users table
|
|
*/
|
|
use crate::schema::users::dsl::*;
|
|
use crate::diesel::ExpressionMethods;
|
|
let _db_result = diesel::delete(users
|
|
.filter(id.eq(api_key.id))
|
|
.filter(secret.eq(api_key.secret.clone())))
|
|
.execute(&conn.0).unwrap();
|
|
|
|
|
|
Status::Accepted
|
|
}
|
|
|
|
#[cfg(test)]
mod auth_tests {
    use crate::invites::static_rocket_route_info_for_use_invite;
    use crate::schema;
    use crate::models::{Invite};
    use super::*;
    use rocket::{self, local::Client};
    use rocket::http::ContentType;
    use diesel::mysql::MysqlConnection;
    use chrono::{Duration, Utc};
    use rand::random;
    use std::env;
    use dotenv::dotenv;
    use serde_json::Value;

    /// Loads `.env` so `DATABASE_URL` is available; panics when it cannot be
    /// loaded, since the test cannot run without it.
    fn setup_dotenv() {
        dotenv().expect("`.env` could not be loaded");
    }

    /// End-to-end check of the join -> leave flow:
    /// 1. create an invite in the database manually,
    /// 2. use that invite to join,
    /// 3. leave again through `POST /auth/leave`.
    #[test]
    fn feed_n_leave() {
        setup_dotenv();

        // `leave` has to be mounted as well, otherwise `/auth/leave` can only 404.
        let app = rocket::ignite()
            .mount("/invite", routes![use_invite])
            .mount("/auth", routes![leave])
            .attach(DBConn::fairing());

        // First we create a new invite.
        let conn = MysqlConnection::establish(&env::var("DATABASE_URL").unwrap()).unwrap();
        let dt = Utc::now() + Duration::minutes(30);
        let invite = Invite {
            id: random::<u64>(),
            uses: 1,
            expires: dt.timestamp() as u64,
        };
        diesel::insert_into(schema::invites::table)
            .values(&invite)
            .execute(&conn)
            .expect("Could not insert the test invite");

        // Use our new invite to "join" the server.
        let rocket_c = Client::new(app).expect("Invalid rocket instance");
        let mut response = rocket_c
            .get(format!("/invite/join/{}/{}", invite.id, "billybob"))
            .dispatch();
        let body: String = response.body_string().unwrap();
        let api_key: Value = serde_json::from_str(&body).unwrap();

        // Formatting a JSON string value directly would include its quotes,
        // so pull the raw string out first.
        // NOTE(review): the secret is not URL-encoded here — confirm it only
        // contains form-safe characters.
        let secret = api_key["secret"]
            .as_str()
            .expect("join response is missing a string `secret`");
        let body_params = format!("id={}&secret={}", api_key["id"], secret);
        println!("Parameters being sent {}", body_params);

        // `leave` is a POST route that reads a form body, so the request
        // needs the POST method and the form content type.
        let leave_response = rocket_c
            .post("/auth/leave")
            .header(ContentType::Form)
            .body(body_params)
            .dispatch();

        // `leave` answers with 202 Accepted on success.
        assert_eq!(leave_response.status(), Status::Accepted);
        println!("{}", body);
    }
}
|