From b9fd18dc538b59d5a901057466066e2e62c625c7 Mon Sep 17 00:00:00 2001 From: pommicket Date: Sun, 11 Aug 2024 23:46:21 -0400 Subject: enforce player limit and some other things --- server/src/main.rs | 359 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 236 insertions(+), 123 deletions(-) (limited to 'server/src') diff --git a/server/src/main.rs b/server/src/main.rs index c289c89..f695910 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,16 +1,18 @@ -use anyhow::anyhow; use futures_util::{SinkExt, StreamExt}; use rand::seq::SliceRandom; use rand::Rng; +use std::collections::HashMap; use std::io::prelude::*; use std::net::SocketAddr; +use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::io::AsyncWriteExt; -use tokio::sync::RwLock; +use tokio::sync::{Mutex, RwLock}; use tungstenite::protocol::Message; const PUZZLE_ID_CHARSET: &[u8] = b"23456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ"; const PUZZLE_ID_LEN: usize = 7; +const MAX_PLAYERS: u16 = 20; fn generate_puzzle_id() -> [u8; PUZZLE_ID_LEN] { let mut rng = rand::thread_rng(); @@ -22,51 +24,92 @@ struct Server { puzzles: sled::Tree, pieces: sled::Tree, connectivity: sled::Tree, + // keep this in memory, since we want to reset it to 0 when the server restarts + player_counts: Mutex>, wikimedia_featured: Vec, wikimedia_potd: RwLock, } -fn get_puzzle_info(server: &Server, id: &[u8]) -> anyhow::Result> { +#[derive(Debug)] +enum Error { + Tungstenite(tungstenite::Error), + Sled(sled::Error), + IO(std::io::Error), + UTF8(std::str::Utf8Error), + BadPuzzleID, + BadPieceID, + BadSyntax, + ImageURLTooLong, + TooManyPieces, + TooManyPlayers, + NotJoined, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::BadPieceID => write!(f, "bad piece ID"), + Error::BadPuzzleID => write!(f, "bad puzzle ID"), + Error::BadSyntax => write!(f, "bad syntax"), + Error::ImageURLTooLong => write!(f, "image URL too long"), + Error::TooManyPieces => write!(f, "too many pieces"), + Error::NotJoined => write!(f, "haven't joined a puzzle"), + Error::TooManyPlayers => write!(f, "too many players"), + Error::Sled(e) => write!(f, "{e}"), + Error::IO(e) => write!(f, "{e}"), + Error::UTF8(e) => write!(f, "{e}"), + Error::Tungstenite(e) => write!(f, "{e}"), + } + } +} + +impl From for Error { + fn from(value: sled::Error) -> Self { + Self::Sled(value) + } +} + +impl From for Error { + fn from(value: tungstenite::Error) -> Self { + Self::Tungstenite(value) + } +} +impl From for Error { + fn from(value: std::io::Error) -> Self { + Self::IO(value) + } +} +impl From for Error { + fn from(value: std::str::Utf8Error) -> Self { + Self::UTF8(value) + } +} + +type Result = std::result::Result; + +fn get_puzzle_info(server: &Server, id: &[u8]) -> Result> { if id.len() != PUZZLE_ID_LEN { - return Err(anyhow!("bad puzzle ID")); + return Err(Error::BadPuzzleID); } let mut data = vec![1, 0, 0, 0, 0, 0, 0, 0]; // opcode + padding - let puzzle = server - .puzzles - .get(id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; + let puzzle = server.puzzles.get(id)?.ok_or(Error::BadPuzzleID)?; data.extend_from_slice(&puzzle); while data.len() % 8 != 0 { // padding data.push(0); } - let pieces = server - .pieces - .get(id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; + let pieces = server.pieces.get(id)?.ok_or(Error::BadPuzzleID)?; data.extend_from_slice(&pieces); - let connectivity = server - .connectivity - .get(id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; + let connectivity = server.connectivity.get(id)?.ok_or(Error::BadPuzzleID)?; data.extend_from_slice(&connectivity); Ok(data) } -async fn handle_connection( +async fn handle_websocket( server: &Server, - conn: &mut tokio::net::TcpStream, -) -> anyhow::Result<()> { - let mut ws = tokio_tungstenite::accept_async_with_config( - conn, - Some(tungstenite::protocol::WebSocketConfig { - max_message_size: Some(65536), - max_frame_size: Some(65536), - ..Default::default() - }), - ) - .await?; - let mut puzzle_id = None; + puzzle_id: &mut Option<[u8; PUZZLE_ID_LEN]>, + ws: &mut tokio_tungstenite::WebSocketStream<&mut tokio::net::TcpStream>, +) -> Result<()> { while let Some(message) = ws.next().await { let message = message?; if matches!(message, Message::Close(_)) { @@ -76,14 +119,22 @@ async fn handle_connection( let text = text.trim(); if let Some(dimensions) = text.strip_prefix("new ") { let mut parts = dimensions.split(' '); - let width: u8 = parts.next().ok_or_else(|| anyhow!("no width"))?.parse()?; - let height: u8 = parts.next().ok_or_else(|| anyhow!("no height"))?.parse()?; - let url: &str = parts.next().ok_or_else(|| anyhow!("no url"))?; + let width: u8 = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let height: u8 = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let url: &str = parts.next().ok_or(Error::BadSyntax)?; if url.len() > 255 { - return Err(anyhow!("image URL too long")); + return Err(Error::ImageURLTooLong); } if (width as u16) * (height as u16) > 1000 { - return Err(anyhow!("too many pieces")); + return Err(Error::TooManyPieces); } let mut puzzle_data = vec![width, height]; let timestamp: u64 = SystemTime::now() @@ -120,7 +171,7 @@ async fn handle_connection( } } drop(puzzle_data); // should be empty now - puzzle_id = Some(id); + *puzzle_id = Some(id); let pieces_data: Box<[u8]>; { let mut rng = rand::thread_rng(); @@ -154,17 +205,28 @@ async fn handle_connection( connectivity_data.extend(i.to_le_bytes()); } server.connectivity.insert(id, connectivity_data)?; - ws.send(Message::Text(format!("id: {}", std::str::from_utf8(&id)?))) - .await?; + server.player_counts.lock().await.insert(id, 1); + ws.send(Message::Text(format!( + "id: {}", + std::str::from_utf8(&id).expect("puzzle ID has bad utf-8???") + ))) + .await?; let info = get_puzzle_info(server, &id)?; ws.send(Message::Binary(info)).await?; } else if let Some(id) = text.strip_prefix("join ") { - let id = id.as_bytes().try_into()?; - puzzle_id = Some(id); + let id = id.as_bytes().try_into().map_err(|_| Error::BadSyntax)?; + let mut player_counts = server.player_counts.lock().await; + let entry = player_counts.entry(id).or_default(); + if *entry >= MAX_PLAYERS { + return Err(Error::TooManyPlayers); + } + *entry += 1; + drop(player_counts); // release lock + *puzzle_id = Some(id); let info = get_puzzle_info(server, &id)?; ws.send(Message::Binary(info)).await?; } else if text.starts_with("move ") { - let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("move without puzzle ID"))?; + let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; #[derive(Clone, Copy)] struct Motion { piece: usize, @@ -175,93 +237,102 @@ async fn handle_connection( for line in text.split('\n') { let mut parts = line.split(' '); parts.next(); // skip "move" - let piece: usize = - parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?; - let x: f32 = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?; - let y: f32 = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?; + let piece: usize = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let x: f32 = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let y: f32 = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; motions.push(Motion { piece, x, y }); } - loop { - let curr_pieces = server - .pieces - .get(puzzle_id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; - let mut new_pieces = curr_pieces.to_vec(); - for Motion { piece, x, y } in motions.iter().copied() { - new_pieces - .get_mut(8 * piece..8 * piece + 4) - .ok_or_else(|| anyhow!("bad piece ID"))? - .copy_from_slice(&x.to_le_bytes()); - new_pieces - .get_mut(8 * piece + 4..8 * piece + 8) - .ok_or_else(|| anyhow!("bad piece ID"))? - .copy_from_slice(&y.to_le_bytes()); - } - if server - .pieces - .compare_and_swap(puzzle_id, Some(curr_pieces), Some(new_pieces))? - .is_ok() - { - break; - } - tokio::time::sleep(std::time::Duration::from_millis(1)).await; // yield maybe (don't let contention hog resources) + let mut error = None; + server + .pieces + .fetch_and_update(puzzle_id, |curr_pieces: Option<&[u8]>| { + let Some(curr_pieces) = curr_pieces else { + error = Some(Error::BadPuzzleID); + return None; + }; + let mut new_pieces = curr_pieces.to_vec(); + for Motion { piece, x, y } in motions.iter().copied() { + let Some(slice) = new_pieces.get_mut(8 * piece..8 * piece + 8) else { + error = Some(Error::BadPieceID); + break; + }; + slice[0..4].copy_from_slice(&x.to_le_bytes()); + slice[4..8].copy_from_slice(&y.to_le_bytes()); + } + Some(new_pieces) + })?; + if let Some(error) = error { + return Err(error); } ws.send(Message::Text("ack".to_string())).await?; } else if let Some(data) = text.strip_prefix("connect ") { let mut parts = data.split(' '); - let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("connect without puzzle ID"))?; - let piece1: usize = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?; - let piece2: usize = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?; - loop { - let curr_connectivity = server - .connectivity - .get(puzzle_id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; - let mut new_connectivity = curr_connectivity.to_vec(); - if piece1 >= curr_connectivity.len() / 2 - || piece2 >= curr_connectivity.len() / 2 - { - return Err(anyhow!("bad piece ID")); - } - let piece2_group = u16::from_le_bytes([ - curr_connectivity[piece2 * 2], - curr_connectivity[piece2 * 2 + 1], - ]); - let a = curr_connectivity[piece1 * 2]; - let b = curr_connectivity[piece1 * 2 + 1]; - for piece in 0..curr_connectivity.len() / 2 { - let piece_group = u16::from_le_bytes([ - curr_connectivity[piece * 2], - curr_connectivity[piece * 2 + 1], + let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; + let piece1: usize = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let piece2: usize = parts + .next() + .ok_or(Error::BadSyntax)? + .parse() + .map_err(|_| Error::BadSyntax)?; + let mut error = None; + server + .connectivity + .fetch_and_update(puzzle_id, |curr_connectivity| { + let Some(curr_connectivity) = curr_connectivity else { + error = Some(Error::BadPuzzleID); + return None; + }; + let mut new_connectivity = curr_connectivity.to_vec(); + if piece1 >= curr_connectivity.len() / 2 + || piece2 >= curr_connectivity.len() / 2 + { + error = Some(Error::BadPieceID); + return Some(new_connectivity); + } + let piece2_group = u16::from_le_bytes([ + curr_connectivity[piece2 * 2], + curr_connectivity[piece2 * 2 + 1], ]); - if piece_group == piece2_group { - new_connectivity[piece * 2] = a; - new_connectivity[piece * 2 + 1] = b; + let a = curr_connectivity[piece1 * 2]; + let b = curr_connectivity[piece1 * 2 + 1]; + for piece in 0..curr_connectivity.len() / 2 { + let piece_group = u16::from_le_bytes([ + curr_connectivity[piece * 2], + curr_connectivity[piece * 2 + 1], + ]); + if piece_group == piece2_group { + new_connectivity[piece * 2] = a; + new_connectivity[piece * 2 + 1] = b; + } } - } - if server - .connectivity - .compare_and_swap( - puzzle_id, - Some(curr_connectivity), - Some(new_connectivity), - )? - .is_ok() - { - break; - } - tokio::time::sleep(std::time::Duration::from_millis(1)).await; // yield maybe (don't let contention hog resources) + Some(new_connectivity) + })?; + if let Some(error) = error { + return Err(error); } } else if text == "poll" { - let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("poll without puzzle ID"))?; - let pieces = server - .pieces - .get(puzzle_id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; + let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; + let pieces = server.pieces.get(puzzle_id)?.ok_or(Error::BadPuzzleID)?; let connectivity = server .connectivity .get(puzzle_id)? - .ok_or_else(|| anyhow!("bad puzzle ID"))?; + .ok_or(Error::BadPuzzleID)?; let mut data = vec![2, 0, 0, 0, 0, 0, 0, 0]; // opcode / version number + padding data.extend_from_slice(&pieces); data.extend_from_slice(&connectivity); @@ -285,18 +356,51 @@ async fn handle_connection( Ok(()) } +async fn handle_connection(server: &Server, conn: &mut tokio::net::TcpStream) -> Result<()> { + let mut puzzle_id = None; + let mut ws = tokio_tungstenite::accept_async_with_config( + conn, + Some(tungstenite::protocol::WebSocketConfig { + max_message_size: Some(65536), + max_frame_size: Some(65536), + ..Default::default() + }), + ) + .await?; + let status = handle_websocket(server, &mut puzzle_id, &mut ws).await; + if let Err(e) = &status { + ws.send(Message::Text(format!("error {e}"))).await?; + }; + if let Some(puzzle_id) = puzzle_id { + *server + .player_counts + .lock() + .await + .entry(puzzle_id) + .or_insert_with(|| { + eprintln!("negative player count??"); + // prevent underflow + 1 + }) -= 1; + } + status +} + fn read_to_lines(path: &str) -> std::io::Result> { let file = std::fs::File::open(path)?; let reader = std::io::BufReader::new(file); reader.lines().collect() } -async fn try_get_potd() -> anyhow::Result { +async fn try_get_potd() -> Result { let output = tokio::process::Command::new("python3") .arg("potd.py") .output() .await?; - Ok(String::from_utf8(output.stdout)?.trim().to_string()) + Ok(String::from_utf8(output.stdout) + .map_err(|e| e.utf8_error())? + .trim() + .to_string()) } async fn get_potd() -> String { match try_get_potd().await { @@ -320,8 +424,7 @@ async fn main() { } }; let start_time = SystemTime::now(); - // leak this since we need all threads to be able to access this - let server: &'static Server = Box::leak(Box::new({ + let server_arc: Arc = Arc::new({ let wikimedia_featured = read_to_lines("featuredpictures.txt").expect("Couldn't read featuredpictures.txt"); let db = sled::open("database.sled").expect("error opening database"); @@ -334,12 +437,15 @@ async fn main() { Server { puzzles, pieces, + player_counts: Mutex::new(HashMap::new()), connectivity, wikimedia_potd: RwLock::new(potd), wikimedia_featured, } - })); + }); + let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { + let server: &Server = server_arc_clone.as_ref(); fn next_day(t: SystemTime) -> SystemTime { let day = 60 * 60 * 24; let dt = t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(); @@ -354,7 +460,9 @@ async fn main() { last_time = SystemTime::now(); } }); - tokio::task::spawn(async { + let server_arc_clone = server_arc.clone(); + tokio::task::spawn(async move { + let server: &Server = server_arc_clone.as_ref(); loop { // TODO : sweep let now = SystemTime::now(); @@ -376,15 +484,18 @@ async fn main() { server .puzzles .remove(&key) - .expect("sweep failed to delete entry"); + .expect("sweep failed to delete puzzle"); server .pieces .remove(&key) - .expect("sweep failed to delete entry"); + .expect("sweep failed to delete pieces"); server .connectivity .remove(&key) - .expect("sweep failed to delete entry"); + .expect("sweep failed to delete connectivity"); + if let Some(key) = <[u8; PUZZLE_ID_LEN]>::try_from(&key[..]).ok() { + server.player_counts.lock().await.remove(&key); + } } tokio::time::sleep(std::time::Duration::from_secs(3600)).await; } @@ -397,7 +508,9 @@ async fn main() { continue; } }; + let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { + let server: &Server = server_arc_clone.as_ref(); match handle_connection(server, &mut stream).await { Ok(()) => {} Err(e) => { -- cgit v1.2.3