summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
Diffstat (limited to 'server')
-rw-r--r--server/src/main.rs62
1 files changed, 42 insertions, 20 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
index 5e62213..c019a2a 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -15,7 +15,7 @@ use tungstenite::protocol::Message;
const PUZZLE_ID_CHARSET: &[u8] = b"23456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ";
const PUZZLE_ID_LEN: usize = 7;
-const MAX_PLAYERS: u16 = 20;
+const MAX_PLAYERS: u32 = 20;
const MAX_PIECES: usize = 1000;
const ACTION_MOVE: u32 = 3;
const ACTION_CONNECT: u32 = 4;
@@ -28,7 +28,7 @@ fn generate_puzzle_id() -> [u8; PUZZLE_ID_LEN] {
#[derive(Debug)]
struct Server {
// keep this in memory, since we want to reset it to 0 when the server restarts
- player_counts: Mutex<HashMap<[u8; PUZZLE_ID_LEN], u16>>,
+ player_counts: Mutex<HashMap<[u8; PUZZLE_ID_LEN], u32>>,
wikimedia_featured: Vec<String>,
wikimedia_potd: RwLock<String>,
database: tokio_postgres::Client,
@@ -62,6 +62,28 @@ impl Server {
.await
.is_ok())
}
+ async fn increase_player_count(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<()> {
+ let mut player_counts = self.player_counts.lock().await;
+ let entry = player_counts.entry(id).or_insert(0);
+ if *entry >= MAX_PLAYERS {
+ Err(Error::TooManyPlayers)
+ } else {
+ *entry += 1;
+ Ok(())
+ }
+ }
+ async fn decrease_player_count(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<()> {
+ let mut player_counts = self.player_counts.lock().await;
+ let std::collections::hash_map::Entry::Occupied(mut o) = player_counts.entry(id) else {
+ return Err(Error::BadPuzzleID);
+ };
+ if *o.get() <= 1 {
+ o.remove();
+ } else {
+ *o.get_mut() -= 1;
+ }
+ Ok(())
+ }
async fn set_puzzle_data(
&self,
id: [u8; PUZZLE_ID_LEN],
@@ -186,7 +208,7 @@ impl std::fmt::Display for Error {
match self {
Error::BadPieceID => write!(f, "bad piece ID"),
Error::BadPuzzleID => write!(f, "bad puzzle ID"),
- Error::BadSyntax(s) => write!(f, "bad syntax: {s}"),
+ Error::BadSyntax(s) => write!(f, "{s}"),
Error::ImageURLTooLong => write!(f, "image URL too long"),
Error::TooManyPieces => write!(f, "too many pieces"),
Error::NotJoined => write!(f, "haven't joined a puzzle"),
@@ -326,16 +348,21 @@ async fn handle_websocket(
.as_bytes()
.try_into()
.map_err(|_| Error::BadSyntax("bad join ID"))?;
- let mut player_counts = server.player_counts.lock().await;
- let entry = player_counts.entry(id).or_default();
- if *entry >= MAX_PLAYERS {
- return Err(Error::TooManyPlayers);
- }
- *entry += 1;
- drop(player_counts); // release lock
+ server.increase_player_count(id).await?;
*puzzle_id = Some(id);
let info = get_puzzle_info(server, &id).await?;
ws.send(Message::Binary(info)).await?;
+ } else if let Some(id) = text.strip_prefix("rejoin ") {
+ let id = id
+ .as_bytes()
+ .try_into()
+ .map_err(|_| Error::BadSyntax("bad join ID"))?;
+ if puzzle_id.is_some() {
+ return Err(Error::BadSyntax("unexpected rejoin"));
+ }
+ server.increase_player_count(id).await?;
+ *puzzle_id = Some(id);
+ ws.send(Message::Text("rejoined".to_string())).await?;
} else if text == "poll" {
let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?;
let PieceInfo {
@@ -434,16 +461,11 @@ async fn handle_connection(server: &Server, conn: &mut tokio::net::TcpStream) ->
ws.send(Message::Text(format!("error {e}"))).await?;
};
if let Some(puzzle_id) = puzzle_id {
- *server
- .player_counts
- .lock()
- .await
- .entry(puzzle_id)
- .or_insert_with(|| {
- eprintln!("negative player count??");
- // prevent underflow
- 1
- }) -= 1;
+ if let Err(e) = server.decrease_player_count(puzzle_id).await {
+ eprintln!(
+ "unexpected error while decreasing player count for puzzle {puzzle_id:?}: {e}"
+ );
+ }
}
status
}