+ JWT Authentication

* Server JWT's and Client JWT's built with seperate signature types
This commit is contained in:
shockrah
2021-03-30 12:24:10 -07:00
parent 5bbc57313f
commit 56e4e22b4c
6 changed files with 643 additions and 66 deletions

87
rtc-server/src/auth.rs Normal file
View File

@@ -0,0 +1,87 @@
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Serialize, Deserialize, de::DeserializeOwned};
use jsonwebtoken::{decode, DecodingKey};
use jsonwebtoken::{Validation, Algorithm};
use lazy_static::lazy_static;
use crate::peers::ConnectionType;
lazy_static! {
// This hmac is the key we use to sign ONLY API <-> WSS communications
static ref API_HMAC_SECRET: Vec<u8> = {
std::fs::read("wss-hmac.secret").expect("[WSS FATAL] Couldn't get WSS HMAC secret")
};
static ref USER_HMAC: Vec<u8> = {
std::fs::read("hmac.secret").expect("[WSS FATAL] no user hmac.secret found")
};
}
trait Claims {
/// Returns the unix timestamp in ms 0 if this field does not exist
fn time(&self) -> i64;
/// Returns user id of if one is present
fn sub(&self) -> Option<u64>;
}
#[derive(Deserialize, Serialize)]
struct APIClaim;
impl Claims for APIClaim {
fn time(&self) -> i64 { 0 }
fn sub(&self) -> Option<u64> { None }
}
#[derive(Debug, Serialize, Deserialize)]
struct UserClaim {
sub: u64, // user id
exp: i64, // expiry date
cookie: String, // unique cookie value
}
impl Claims for UserClaim {
fn time(&self) -> i64 { self.exp }
fn sub(&self) -> Option<u64> { Some(self.sub) }
}
fn verify_token<T>(token: &str) -> Option<ConnectionType> where
T: DeserializeOwned + Claims
{
let dk = DecodingKey::from_secret(&USER_HMAC);
let algo = Algorithm::HS512;
if let Ok(decoded) = decode::<T>(token, &dk, &Validation::new(algo)) {
let time = decoded.claims.time();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("[WSS UMMM] Couldn't get unix time")
.as_millis() as i64;
let active = now < time;
match active {
true => Some(ConnectionType::User),
false => None
}
}
// Only fallback to checking for a server connection since this basically never happens
// compared to the user connection branch
else {
let dk = DecodingKey::from_secret(&API_HMAC_SECRET);
if let Ok(_decoded) = decode::<T>(token, &dk, &Validation::new(algo)) {
Some(ConnectionType::Server)
} else {
None
}
}
}
pub fn verify(token: &str) -> Option<ConnectionType> {
match verify_token::<UserClaim>(token) {
Some(user) => Some(user),
None => verify_token::<APIClaim>(token)
}
}

View File

@@ -1,6 +0,0 @@
// The concept of channels here is a bit different from
struct VoiceCollection {
}
struct TextCollection {
}

View File

@@ -1,24 +1,29 @@
mod auth;
mod peers;
use std::sync::{Arc, Mutex};
use std::net::SocketAddr;
use std::collections::HashMap;
use futures::StreamExt;
use futures::channel::mpsc::UnboundedSender;
use futures::channel::mpsc::unbounded;
use futures::future;
use futures::pin_mut;
use futures::stream::TryStreamExt;
use tokio_tungstenite as tokio_ws;
use tokio_ws::tungstenite::Message;
use tokio_ws::tungstenite::handshake::server::ErrorResponse;
use tokio_ws::tungstenite::http::{Response, Request};
use tokio::net::{TcpListener, TcpStream};
use clap::{Arg, App};
type Tx = UnboundedSender<Message>;
type Peers = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
use peers::{Peer, Channel, PeerMap, ConnectionType};
macro_rules! header_err {
($s:literal) => {
Err(ErrorResponse::new(Some($s.to_string())))
}
}
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
@@ -33,73 +38,96 @@ async fn main() -> Result<(), std::io::Error> {
.takes_value(true))
.get_matches();
let addr = match matches.value_of("port") {
let connections: PeerMap = Arc::new(Mutex::new(HashMap::new()));
// Websocket server initialization
let wsaddr = match matches.value_of("port") {
Some(pval) => {
let port = pval.parse::<u16>().unwrap();
format!("127.0.0.1:{}", port)
},
None => format!("127.0.0.1:5648")
};
let wssocket = TcpListener::bind(&wsaddr).await?;
println!("[INFO] WSS Listening on {}", wsaddr);
let socket = TcpListener::bind(&addr).await?;
println!("[INFO] Listening on {}", addr);
let peers = Peers::new(Mutex::new(HashMap::new()));
while let Ok((stream, _)) = socket.accept().await {
tokio::spawn(handle_connections(stream, peers.clone()));
while let Ok((stream, _)) = wssocket.accept().await {
tokio::spawn(handle_connections(stream, connections.clone()));
}
Ok(())
}
fn header_validation(request: &Request<()>, response: Response<()>) -> Result<Response<()>, ErrorResponse> {
// validate that the required headers are presetn
// Required headers: Subscribe-Channel & Jwt-Token
let valid_channels = ["/text", "/voice"];
let path = request.uri();
for (hdr, val) in request.headers().iter() {
println!("{:?} -> {:?}", hdr, val);
}
Ok(response)
}
async fn handle_connections(stream: TcpStream, peers: Peers) {
async fn handle_connections(stream: TcpStream, peermap: PeerMap) {
let addr = stream.peer_addr().expect("[ERROR] Peer address not found");
// NOTE: this call underneath actually blocks which blows but it will do for now
// NOTE: find some kind of way of doing async callbacks here so that we can scale
let ws_stream = tokio_ws::accept_hdr_async(stream, header_validation)
.await.expect("[ERROR] Could not finish handshake");
println!("[INFO] New websocket connection: {}", addr.ip());
let mut domain: Option<Channel> = None;
let ws_stream = tokio_ws::accept_hdr_async(stream,
|request:&Request<()>, response:Response<()>| -> Result<Response<()>, ErrorResponse> {
domain = match request.uri().path() {
"/voice" => Some(Channel::Voice),
"/text" => Some(Channel::Text),
_ => None
};
println!("{:?}", request.headers());
let entry = request.headers()
.iter().find(|(name, _)| name.as_str() == "jwt");
if let Some((_, jwt)) = entry {
match auth::verify(jwt.to_str().expect("Unable to convert header to str")) {
Some(_conn_type) => Ok(response),
None => panic!("[WSS] Unable to verify connection")
}
} else {
header_err!("JWT not found in header")
}
}).await;
let ws_stream = match ws_stream {
Ok(stream) => {
println!("[WSS INFO] New connection established");
stream
}
Err(e) => panic!(format!("[WSS ERROR] Could not finish handshake: {}", e))
};
let (tx, rx) = unbounded(); // split the peer's write and read streams
peers.lock().unwrap().insert(addr, tx);
let peer = Peer::new(tx, domain); // safe as the handshake fails with None domain
println!("{:?}", peer);
// Add the new peer
peermap.lock().unwrap().insert(addr, peer);
let (write, read) = ws_stream.split();
let broadcast_incoming = read.try_for_each(|msg| {
println!("Got message from {}: {}", addr, msg.to_text().unwrap());
// hold a ref to the peers map so that we can iterate through them
// OPTI NOTE: because this runs next to the REST API it makes sense to avoid
// doing anything if the user-level peers try do anything but listen
let peers = peers.lock().unwrap();
let peers = peermap.lock().unwrap();
// TODO: restructure this so that the server connection
// never gets rans over to avoid this meme of a .collect
// collect everyone except the server connection
let recipients = peers
.iter().filter(|(p_addr, _)| p_addr != &&addr ) // avoid echo back to sender
.map(|(_, ws_sink)| ws_sink);
.iter().filter(|(p_addr, meta)| p_addr != &&addr || meta.conn == ConnectionType::Server)
.map(|(_, sink)| sink);
for rec in recipients {
rec.unbounded_send(msg.clone()).expect("[WARN] Unable to send message");
for recv in recipients {
println!("{:?}", recv);
recv.try_send_text(msg.clone());
}
future::ok(())
});
// magic
// magic
let forward = rx.map(Ok).forward(write);
pin_mut!(broadcast_incoming, forward);
future::select(broadcast_incoming, forward).await;
println!("{} dc'd", &addr);
peers.lock().unwrap().remove(&addr);
peermap.lock().unwrap().remove(&addr);
}

51
rtc-server/src/peers.rs Normal file
View File

@@ -0,0 +1,51 @@
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use std::net::SocketAddr;
use futures::channel::mpsc::UnboundedSender;
use tokio_tungstenite as tokio_ws;
use tokio_ws::tungstenite::Message;
pub type PeerMap = Arc<Mutex<HashMap<SocketAddr, Peer>>>;
type Tx = UnboundedSender<Message>;
#[derive(Debug, PartialEq)]
pub enum ConnectionType {
User,
Server
}
#[derive(Clone, Debug, PartialEq)]
pub enum Channel {
Text,
Voice,
}
#[derive(Debug)]
pub struct Peer {
transfer: Tx,
channel: Option<Channel>,
pub conn: ConnectionType
}
impl Peer {
pub fn new(transfer: Tx, channel: Option<Channel>) -> Self {
Self {
transfer,
channel: channel.clone(),
conn: match channel {
Some(_) => ConnectionType::User,
_ => ConnectionType::Server
}
}
}
pub fn try_send_text(&self, msg: Message) {
if self.channel == Some(Channel::Text) {
self.transfer.unbounded_send(msg)
.expect("[WSS-COMM-ERROR] Unable to notify peer");
}
}
}