diff options
Diffstat (limited to 'server')
-rw-r--r-- | server/src/main.rs | 62 |
1 files changed, 42 insertions, 20 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index 5e62213..c019a2a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -15,7 +15,7 @@ use tungstenite::protocol::Message; const PUZZLE_ID_CHARSET: &[u8] = b"23456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ"; const PUZZLE_ID_LEN: usize = 7; -const MAX_PLAYERS: u16 = 20; +const MAX_PLAYERS: u32 = 20; const MAX_PIECES: usize = 1000; const ACTION_MOVE: u32 = 3; const ACTION_CONNECT: u32 = 4; @@ -28,7 +28,7 @@ fn generate_puzzle_id() -> [u8; PUZZLE_ID_LEN] { #[derive(Debug)] struct Server { // keep this in memory, since we want to reset it to 0 when the server restarts - player_counts: Mutex<HashMap<[u8; PUZZLE_ID_LEN], u16>>, + player_counts: Mutex<HashMap<[u8; PUZZLE_ID_LEN], u32>>, wikimedia_featured: Vec<String>, wikimedia_potd: RwLock<String>, database: tokio_postgres::Client, @@ -62,6 +62,28 @@ impl Server { .await .is_ok()) } + async fn increase_player_count(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<()> { + let mut player_counts = self.player_counts.lock().await; + let entry = player_counts.entry(id).or_insert(0); + if *entry >= MAX_PLAYERS { + Err(Error::TooManyPlayers) + } else { + *entry += 1; + Ok(()) + } + } + async fn decrease_player_count(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<()> { + let mut player_counts = self.player_counts.lock().await; + let std::collections::hash_map::Entry::Occupied(mut o) = player_counts.entry(id) else { + return Err(Error::BadPuzzleID); + }; + if *o.get() <= 1 { + o.remove(); + } else { + *o.get_mut() -= 1; + } + Ok(()) + } async fn set_puzzle_data( &self, id: [u8; PUZZLE_ID_LEN], @@ -186,7 +208,7 @@ impl std::fmt::Display for Error { match self { Error::BadPieceID => write!(f, "bad piece ID"), Error::BadPuzzleID => write!(f, "bad puzzle ID"), - Error::BadSyntax(s) => write!(f, "bad syntax: {s}"), + Error::BadSyntax(s) => write!(f, "{s}"), Error::ImageURLTooLong => write!(f, "image URL too long"), Error::TooManyPieces => write!(f, "too many pieces"), Error::NotJoined => write!(f, "haven't joined a puzzle"), @@ -326,16 +348,21 @@ async fn handle_websocket( .as_bytes() .try_into() .map_err(|_| Error::BadSyntax("bad join ID"))?; - let mut player_counts = server.player_counts.lock().await; - let entry = player_counts.entry(id).or_default(); - if *entry >= MAX_PLAYERS { - return Err(Error::TooManyPlayers); - } - *entry += 1; - drop(player_counts); // release lock + server.increase_player_count(id).await?; *puzzle_id = Some(id); let info = get_puzzle_info(server, &id).await?; ws.send(Message::Binary(info)).await?; + } else if let Some(id) = text.strip_prefix("rejoin ") { + let id = id + .as_bytes() + .try_into() + .map_err(|_| Error::BadSyntax("bad join ID"))?; + if puzzle_id.is_some() { + return Err(Error::BadSyntax("unexpected rejoin")); + } + server.increase_player_count(id).await?; + *puzzle_id = Some(id); + ws.send(Message::Text("rejoined".to_string())).await?; } else if text == "poll" { let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; let PieceInfo { @@ -434,16 +461,11 @@ async fn handle_connection(server: &Server, conn: &mut tokio::net::TcpStream) -> ws.send(Message::Text(format!("error {e}"))).await?; }; if let Some(puzzle_id) = puzzle_id { - *server - .player_counts - .lock() - .await - .entry(puzzle_id) - .or_insert_with(|| { - eprintln!("negative player count??"); - // prevent underflow - 1 - }) -= 1; + if let Err(e) = server.decrease_player_count(puzzle_id).await { + eprintln!( + "unexpected error while decreasing player count for puzzle {puzzle_id:?}: {e}" + ); + } } status } |