+ JWT Authentication
* Server JWT's and Client JWT's built with seperate signature types
This commit is contained in:
87
rtc-server/src/auth.rs
Normal file
87
rtc-server/src/auth.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use serde::{Serialize, Deserialize, de::DeserializeOwned};
|
||||
use jsonwebtoken::{decode, DecodingKey};
|
||||
use jsonwebtoken::{Validation, Algorithm};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::peers::ConnectionType;
|
||||
|
||||
lazy_static! {
|
||||
// This hmac is the key we use to sign ONLY API <-> WSS communications
|
||||
static ref API_HMAC_SECRET: Vec<u8> = {
|
||||
std::fs::read("wss-hmac.secret").expect("[WSS FATAL] Couldn't get WSS HMAC secret")
|
||||
};
|
||||
|
||||
static ref USER_HMAC: Vec<u8> = {
|
||||
std::fs::read("hmac.secret").expect("[WSS FATAL] no user hmac.secret found")
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
trait Claims {
|
||||
/// Returns the unix timestamp in ms 0 if this field does not exist
|
||||
fn time(&self) -> i64;
|
||||
/// Returns user id of if one is present
|
||||
fn sub(&self) -> Option<u64>;
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct APIClaim;
|
||||
|
||||
impl Claims for APIClaim {
|
||||
fn time(&self) -> i64 { 0 }
|
||||
fn sub(&self) -> Option<u64> { None }
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct UserClaim {
|
||||
sub: u64, // user id
|
||||
exp: i64, // expiry date
|
||||
cookie: String, // unique cookie value
|
||||
}
|
||||
|
||||
impl Claims for UserClaim {
|
||||
fn time(&self) -> i64 { self.exp }
|
||||
fn sub(&self) -> Option<u64> { Some(self.sub) }
|
||||
}
|
||||
|
||||
|
||||
|
||||
fn verify_token<T>(token: &str) -> Option<ConnectionType> where
|
||||
T: DeserializeOwned + Claims
|
||||
{
|
||||
let dk = DecodingKey::from_secret(&USER_HMAC);
|
||||
let algo = Algorithm::HS512;
|
||||
|
||||
if let Ok(decoded) = decode::<T>(token, &dk, &Validation::new(algo)) {
|
||||
let time = decoded.claims.time();
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("[WSS UMMM] Couldn't get unix time")
|
||||
.as_millis() as i64;
|
||||
let active = now < time;
|
||||
match active {
|
||||
true => Some(ConnectionType::User),
|
||||
false => None
|
||||
}
|
||||
}
|
||||
// Only fallback to checking for a server connection since this basically never happens
|
||||
// compared to the user connection branch
|
||||
else {
|
||||
let dk = DecodingKey::from_secret(&API_HMAC_SECRET);
|
||||
if let Ok(_decoded) = decode::<T>(token, &dk, &Validation::new(algo)) {
|
||||
Some(ConnectionType::Server)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn verify(token: &str) -> Option<ConnectionType> {
|
||||
match verify_token::<UserClaim>(token) {
|
||||
Some(user) => Some(user),
|
||||
None => verify_token::<APIClaim>(token)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
// The concept of channels here is a bit different from
|
||||
|
||||
struct VoiceCollection {
|
||||
}
|
||||
struct TextCollection {
|
||||
}
|
||||
@@ -1,24 +1,29 @@
|
||||
mod auth;
|
||||
mod peers;
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::net::SocketAddr;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use futures::StreamExt;
|
||||
use futures::channel::mpsc::UnboundedSender;
|
||||
use futures::channel::mpsc::unbounded;
|
||||
use futures::future;
|
||||
use futures::pin_mut;
|
||||
use futures::stream::TryStreamExt;
|
||||
|
||||
use tokio_tungstenite as tokio_ws;
|
||||
use tokio_ws::tungstenite::Message;
|
||||
use tokio_ws::tungstenite::handshake::server::ErrorResponse;
|
||||
use tokio_ws::tungstenite::http::{Response, Request};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use clap::{Arg, App};
|
||||
|
||||
type Tx = UnboundedSender<Message>;
|
||||
type Peers = Arc<Mutex<HashMap<SocketAddr, Tx>>>;
|
||||
use peers::{Peer, Channel, PeerMap, ConnectionType};
|
||||
|
||||
macro_rules! header_err {
|
||||
($s:literal) => {
|
||||
Err(ErrorResponse::new(Some($s.to_string())))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), std::io::Error> {
|
||||
@@ -33,73 +38,96 @@ async fn main() -> Result<(), std::io::Error> {
|
||||
.takes_value(true))
|
||||
.get_matches();
|
||||
|
||||
let addr = match matches.value_of("port") {
|
||||
let connections: PeerMap = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
// Websocket server initialization
|
||||
let wsaddr = match matches.value_of("port") {
|
||||
Some(pval) => {
|
||||
let port = pval.parse::<u16>().unwrap();
|
||||
format!("127.0.0.1:{}", port)
|
||||
},
|
||||
None => format!("127.0.0.1:5648")
|
||||
};
|
||||
let wssocket = TcpListener::bind(&wsaddr).await?;
|
||||
println!("[INFO] WSS Listening on {}", wsaddr);
|
||||
|
||||
let socket = TcpListener::bind(&addr).await?;
|
||||
|
||||
println!("[INFO] Listening on {}", addr);
|
||||
|
||||
let peers = Peers::new(Mutex::new(HashMap::new()));
|
||||
|
||||
while let Ok((stream, _)) = socket.accept().await {
|
||||
tokio::spawn(handle_connections(stream, peers.clone()));
|
||||
while let Ok((stream, _)) = wssocket.accept().await {
|
||||
tokio::spawn(handle_connections(stream, connections.clone()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
fn header_validation(request: &Request<()>, response: Response<()>) -> Result<Response<()>, ErrorResponse> {
|
||||
// validate that the required headers are presetn
|
||||
// Required headers: Subscribe-Channel & Jwt-Token
|
||||
let valid_channels = ["/text", "/voice"];
|
||||
let path = request.uri();
|
||||
|
||||
for (hdr, val) in request.headers().iter() {
|
||||
println!("{:?} -> {:?}", hdr, val);
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn handle_connections(stream: TcpStream, peers: Peers) {
|
||||
async fn handle_connections(stream: TcpStream, peermap: PeerMap) {
|
||||
let addr = stream.peer_addr().expect("[ERROR] Peer address not found");
|
||||
|
||||
// NOTE: this call underneath actually blocks which blows but it will do for now
|
||||
// NOTE: find some kind of way of doing async callbacks here so that we can scale
|
||||
let ws_stream = tokio_ws::accept_hdr_async(stream, header_validation)
|
||||
.await.expect("[ERROR] Could not finish handshake");
|
||||
println!("[INFO] New websocket connection: {}", addr.ip());
|
||||
|
||||
let mut domain: Option<Channel> = None;
|
||||
let ws_stream = tokio_ws::accept_hdr_async(stream,
|
||||
|request:&Request<()>, response:Response<()>| -> Result<Response<()>, ErrorResponse> {
|
||||
|
||||
domain = match request.uri().path() {
|
||||
"/voice" => Some(Channel::Voice),
|
||||
"/text" => Some(Channel::Text),
|
||||
_ => None
|
||||
};
|
||||
|
||||
println!("{:?}", request.headers());
|
||||
let entry = request.headers()
|
||||
.iter().find(|(name, _)| name.as_str() == "jwt");
|
||||
|
||||
if let Some((_, jwt)) = entry {
|
||||
match auth::verify(jwt.to_str().expect("Unable to convert header to str")) {
|
||||
Some(_conn_type) => Ok(response),
|
||||
None => panic!("[WSS] Unable to verify connection")
|
||||
}
|
||||
} else {
|
||||
header_err!("JWT not found in header")
|
||||
}
|
||||
}).await;
|
||||
|
||||
let ws_stream = match ws_stream {
|
||||
Ok(stream) => {
|
||||
println!("[WSS INFO] New connection established");
|
||||
stream
|
||||
}
|
||||
Err(e) => panic!(format!("[WSS ERROR] Could not finish handshake: {}", e))
|
||||
};
|
||||
|
||||
|
||||
let (tx, rx) = unbounded(); // split the peer's write and read streams
|
||||
peers.lock().unwrap().insert(addr, tx);
|
||||
let peer = Peer::new(tx, domain); // safe as the handshake fails with None domain
|
||||
println!("{:?}", peer);
|
||||
|
||||
// Add the new peer
|
||||
peermap.lock().unwrap().insert(addr, peer);
|
||||
|
||||
let (write, read) = ws_stream.split();
|
||||
|
||||
let broadcast_incoming = read.try_for_each(|msg| {
|
||||
println!("Got message from {}: {}", addr, msg.to_text().unwrap());
|
||||
// hold a ref to the peers map so that we can iterate through them
|
||||
// OPTI NOTE: because this runs next to the REST API it makes sense to avoid
|
||||
// doing anything if the user-level peers try do anything but listen
|
||||
let peers = peers.lock().unwrap();
|
||||
let peers = peermap.lock().unwrap();
|
||||
// TODO: restructure this so that the server connection
|
||||
// never gets rans over to avoid this meme of a .collect
|
||||
// collect everyone except the server connection
|
||||
let recipients = peers
|
||||
.iter().filter(|(p_addr, _)| p_addr != &&addr ) // avoid echo back to sender
|
||||
.map(|(_, ws_sink)| ws_sink);
|
||||
.iter().filter(|(p_addr, meta)| p_addr != &&addr || meta.conn == ConnectionType::Server)
|
||||
.map(|(_, sink)| sink);
|
||||
|
||||
for rec in recipients {
|
||||
rec.unbounded_send(msg.clone()).expect("[WARN] Unable to send message");
|
||||
for recv in recipients {
|
||||
println!("{:?}", recv);
|
||||
recv.try_send_text(msg.clone());
|
||||
}
|
||||
future::ok(())
|
||||
});
|
||||
// magic
|
||||
|
||||
// magic
|
||||
let forward = rx.map(Ok).forward(write);
|
||||
pin_mut!(broadcast_incoming, forward);
|
||||
future::select(broadcast_incoming, forward).await;
|
||||
println!("{} dc'd", &addr);
|
||||
peers.lock().unwrap().remove(&addr);
|
||||
peermap.lock().unwrap().remove(&addr);
|
||||
}
|
||||
|
||||
51
rtc-server/src/peers.rs
Normal file
51
rtc-server/src/peers.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use futures::channel::mpsc::UnboundedSender;
|
||||
|
||||
use tokio_tungstenite as tokio_ws;
|
||||
use tokio_ws::tungstenite::Message;
|
||||
|
||||
|
||||
pub type PeerMap = Arc<Mutex<HashMap<SocketAddr, Peer>>>;
|
||||
type Tx = UnboundedSender<Message>;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum ConnectionType {
|
||||
User,
|
||||
Server
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Channel {
|
||||
Text,
|
||||
Voice,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Peer {
|
||||
transfer: Tx,
|
||||
channel: Option<Channel>,
|
||||
pub conn: ConnectionType
|
||||
}
|
||||
|
||||
impl Peer {
|
||||
pub fn new(transfer: Tx, channel: Option<Channel>) -> Self {
|
||||
Self {
|
||||
transfer,
|
||||
channel: channel.clone(),
|
||||
conn: match channel {
|
||||
Some(_) => ConnectionType::User,
|
||||
_ => ConnectionType::Server
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_send_text(&self, msg: Message) {
|
||||
if self.channel == Some(Channel::Text) {
|
||||
self.transfer.unbounded_send(msg)
|
||||
.expect("[WSS-COMM-ERROR] Unable to notify peer");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user