summaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authorpommicket <pommicket@gmail.com>2024-08-11 23:46:21 -0400
committerpommicket <pommicket@gmail.com>2024-08-11 23:46:21 -0400
commitb9fd18dc538b59d5a901057466066e2e62c625c7 (patch)
tree1aedbdfc33ea23027780c408bdeefbac58f861a3 /server
parentc140f2975f076eb481400f1ac7fdec2d00ab73b1 (diff)
enforce player limit and some other things
Diffstat (limited to 'server')
-rw-r--r--server/Cargo.lock7
-rw-r--r--server/Cargo.toml1
-rw-r--r--server/src/main.rs359
3 files changed, 236 insertions, 131 deletions
diff --git a/server/Cargo.lock b/server/Cargo.lock
index c7e5745..610d853 100644
--- a/server/Cargo.lock
+++ b/server/Cargo.lock
@@ -18,12 +18,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
-name = "anyhow"
-version = "1.0.86"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da"
-
-[[package]]
name = "autocfg"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -280,7 +274,6 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
name = "jigsaw-server"
version = "0.1.0"
dependencies = [
- "anyhow",
"futures-util",
"rand",
"sled",
diff --git a/server/Cargo.toml b/server/Cargo.toml
index 49d409e..130968f 100644
--- a/server/Cargo.toml
+++ b/server/Cargo.toml
@@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2021"
[dependencies]
-anyhow = "1.0.86"
futures-util = "0.3"
rand = { version = "0.8.5", features = ["std", "std_rng"] }
sled = "0.34.7"
diff --git a/server/src/main.rs b/server/src/main.rs
index c289c89..f695910 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,16 +1,18 @@
-use anyhow::anyhow;
use futures_util::{SinkExt, StreamExt};
use rand::seq::SliceRandom;
use rand::Rng;
+use std::collections::HashMap;
use std::io::prelude::*;
use std::net::SocketAddr;
+use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::io::AsyncWriteExt;
-use tokio::sync::RwLock;
+use tokio::sync::{Mutex, RwLock};
use tungstenite::protocol::Message;
const PUZZLE_ID_CHARSET: &[u8] = b"23456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ";
const PUZZLE_ID_LEN: usize = 7;
+const MAX_PLAYERS: u16 = 20;
fn generate_puzzle_id() -> [u8; PUZZLE_ID_LEN] {
let mut rng = rand::thread_rng();
@@ -22,51 +24,92 @@ struct Server {
puzzles: sled::Tree,
pieces: sled::Tree,
connectivity: sled::Tree,
+ // keep this in memory, since we want to reset it to 0 when the server restarts
+ player_counts: Mutex<HashMap<[u8; PUZZLE_ID_LEN], u16>>,
wikimedia_featured: Vec<String>,
wikimedia_potd: RwLock<String>,
}
-fn get_puzzle_info(server: &Server, id: &[u8]) -> anyhow::Result<Vec<u8>> {
+#[derive(Debug)]
+enum Error {
+ Tungstenite(tungstenite::Error),
+ Sled(sled::Error),
+ IO(std::io::Error),
+ UTF8(std::str::Utf8Error),
+ BadPuzzleID,
+ BadPieceID,
+ BadSyntax,
+ ImageURLTooLong,
+ TooManyPieces,
+ TooManyPlayers,
+ NotJoined,
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Error::BadPieceID => write!(f, "bad piece ID"),
+ Error::BadPuzzleID => write!(f, "bad puzzle ID"),
+ Error::BadSyntax => write!(f, "bad syntax"),
+ Error::ImageURLTooLong => write!(f, "image URL too long"),
+ Error::TooManyPieces => write!(f, "too many pieces"),
+ Error::NotJoined => write!(f, "haven't joined a puzzle"),
+ Error::TooManyPlayers => write!(f, "too many players"),
+ Error::Sled(e) => write!(f, "{e}"),
+ Error::IO(e) => write!(f, "{e}"),
+ Error::UTF8(e) => write!(f, "{e}"),
+ Error::Tungstenite(e) => write!(f, "{e}"),
+ }
+ }
+}
+
+impl From<sled::Error> for Error {
+ fn from(value: sled::Error) -> Self {
+ Self::Sled(value)
+ }
+}
+
+impl From<tungstenite::Error> for Error {
+ fn from(value: tungstenite::Error) -> Self {
+ Self::Tungstenite(value)
+ }
+}
+impl From<std::io::Error> for Error {
+ fn from(value: std::io::Error) -> Self {
+ Self::IO(value)
+ }
+}
+impl From<std::str::Utf8Error> for Error {
+ fn from(value: std::str::Utf8Error) -> Self {
+ Self::UTF8(value)
+ }
+}
+
+type Result<T> = std::result::Result<T, Error>;
+
+fn get_puzzle_info(server: &Server, id: &[u8]) -> Result<Vec<u8>> {
if id.len() != PUZZLE_ID_LEN {
- return Err(anyhow!("bad puzzle ID"));
+ return Err(Error::BadPuzzleID);
}
let mut data = vec![1, 0, 0, 0, 0, 0, 0, 0]; // opcode + padding
- let puzzle = server
- .puzzles
- .get(id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
+ let puzzle = server.puzzles.get(id)?.ok_or(Error::BadPuzzleID)?;
data.extend_from_slice(&puzzle);
while data.len() % 8 != 0 {
// padding
data.push(0);
}
- let pieces = server
- .pieces
- .get(id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
+ let pieces = server.pieces.get(id)?.ok_or(Error::BadPuzzleID)?;
data.extend_from_slice(&pieces);
- let connectivity = server
- .connectivity
- .get(id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
+ let connectivity = server.connectivity.get(id)?.ok_or(Error::BadPuzzleID)?;
data.extend_from_slice(&connectivity);
Ok(data)
}
-async fn handle_connection(
+async fn handle_websocket(
server: &Server,
- conn: &mut tokio::net::TcpStream,
-) -> anyhow::Result<()> {
- let mut ws = tokio_tungstenite::accept_async_with_config(
- conn,
- Some(tungstenite::protocol::WebSocketConfig {
- max_message_size: Some(65536),
- max_frame_size: Some(65536),
- ..Default::default()
- }),
- )
- .await?;
- let mut puzzle_id = None;
+ puzzle_id: &mut Option<[u8; PUZZLE_ID_LEN]>,
+ ws: &mut tokio_tungstenite::WebSocketStream<&mut tokio::net::TcpStream>,
+) -> Result<()> {
while let Some(message) = ws.next().await {
let message = message?;
if matches!(message, Message::Close(_)) {
@@ -76,14 +119,22 @@ async fn handle_connection(
let text = text.trim();
if let Some(dimensions) = text.strip_prefix("new ") {
let mut parts = dimensions.split(' ');
- let width: u8 = parts.next().ok_or_else(|| anyhow!("no width"))?.parse()?;
- let height: u8 = parts.next().ok_or_else(|| anyhow!("no height"))?.parse()?;
- let url: &str = parts.next().ok_or_else(|| anyhow!("no url"))?;
+ let width: u8 = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let height: u8 = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let url: &str = parts.next().ok_or(Error::BadSyntax)?;
if url.len() > 255 {
- return Err(anyhow!("image URL too long"));
+ return Err(Error::ImageURLTooLong);
}
if (width as u16) * (height as u16) > 1000 {
- return Err(anyhow!("too many pieces"));
+ return Err(Error::TooManyPieces);
}
let mut puzzle_data = vec![width, height];
let timestamp: u64 = SystemTime::now()
@@ -120,7 +171,7 @@ async fn handle_connection(
}
}
drop(puzzle_data); // should be empty now
- puzzle_id = Some(id);
+ *puzzle_id = Some(id);
let pieces_data: Box<[u8]>;
{
let mut rng = rand::thread_rng();
@@ -154,17 +205,28 @@ async fn handle_connection(
connectivity_data.extend(i.to_le_bytes());
}
server.connectivity.insert(id, connectivity_data)?;
- ws.send(Message::Text(format!("id: {}", std::str::from_utf8(&id)?)))
- .await?;
+ server.player_counts.lock().await.insert(id, 1);
+ ws.send(Message::Text(format!(
+ "id: {}",
+ std::str::from_utf8(&id).expect("puzzle ID has bad utf-8???")
+ )))
+ .await?;
let info = get_puzzle_info(server, &id)?;
ws.send(Message::Binary(info)).await?;
} else if let Some(id) = text.strip_prefix("join ") {
- let id = id.as_bytes().try_into()?;
- puzzle_id = Some(id);
+ let id = id.as_bytes().try_into().map_err(|_| Error::BadSyntax)?;
+ let mut player_counts = server.player_counts.lock().await;
+ let entry = player_counts.entry(id).or_default();
+ if *entry >= MAX_PLAYERS {
+ return Err(Error::TooManyPlayers);
+ }
+ *entry += 1;
+ drop(player_counts); // release lock
+ *puzzle_id = Some(id);
let info = get_puzzle_info(server, &id)?;
ws.send(Message::Binary(info)).await?;
} else if text.starts_with("move ") {
- let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("move without puzzle ID"))?;
+ let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?;
#[derive(Clone, Copy)]
struct Motion {
piece: usize,
@@ -175,93 +237,102 @@ async fn handle_connection(
for line in text.split('\n') {
let mut parts = line.split(' ');
parts.next(); // skip "move"
- let piece: usize =
- parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?;
- let x: f32 = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?;
- let y: f32 = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?;
+ let piece: usize = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let x: f32 = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let y: f32 = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
motions.push(Motion { piece, x, y });
}
- loop {
- let curr_pieces = server
- .pieces
- .get(puzzle_id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
- let mut new_pieces = curr_pieces.to_vec();
- for Motion { piece, x, y } in motions.iter().copied() {
- new_pieces
- .get_mut(8 * piece..8 * piece + 4)
- .ok_or_else(|| anyhow!("bad piece ID"))?
- .copy_from_slice(&x.to_le_bytes());
- new_pieces
- .get_mut(8 * piece + 4..8 * piece + 8)
- .ok_or_else(|| anyhow!("bad piece ID"))?
- .copy_from_slice(&y.to_le_bytes());
- }
- if server
- .pieces
- .compare_and_swap(puzzle_id, Some(curr_pieces), Some(new_pieces))?
- .is_ok()
- {
- break;
- }
- tokio::time::sleep(std::time::Duration::from_millis(1)).await; // yield maybe (don't let contention hog resources)
+ let mut error = None;
+ server
+ .pieces
+ .fetch_and_update(puzzle_id, |curr_pieces: Option<&[u8]>| {
+ let Some(curr_pieces) = curr_pieces else {
+ error = Some(Error::BadPuzzleID);
+ return None;
+ };
+ let mut new_pieces = curr_pieces.to_vec();
+ for Motion { piece, x, y } in motions.iter().copied() {
+ let Some(slice) = new_pieces.get_mut(8 * piece..8 * piece + 8) else {
+ error = Some(Error::BadPieceID);
+ break;
+ };
+ slice[0..4].copy_from_slice(&x.to_le_bytes());
+ slice[4..8].copy_from_slice(&y.to_le_bytes());
+ }
+ Some(new_pieces)
+ })?;
+ if let Some(error) = error {
+ return Err(error);
}
ws.send(Message::Text("ack".to_string())).await?;
} else if let Some(data) = text.strip_prefix("connect ") {
let mut parts = data.split(' ');
- let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("connect without puzzle ID"))?;
- let piece1: usize = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?;
- let piece2: usize = parts.next().ok_or_else(|| anyhow!("bad syntax"))?.parse()?;
- loop {
- let curr_connectivity = server
- .connectivity
- .get(puzzle_id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
- let mut new_connectivity = curr_connectivity.to_vec();
- if piece1 >= curr_connectivity.len() / 2
- || piece2 >= curr_connectivity.len() / 2
- {
- return Err(anyhow!("bad piece ID"));
- }
- let piece2_group = u16::from_le_bytes([
- curr_connectivity[piece2 * 2],
- curr_connectivity[piece2 * 2 + 1],
- ]);
- let a = curr_connectivity[piece1 * 2];
- let b = curr_connectivity[piece1 * 2 + 1];
- for piece in 0..curr_connectivity.len() / 2 {
- let piece_group = u16::from_le_bytes([
- curr_connectivity[piece * 2],
- curr_connectivity[piece * 2 + 1],
+ let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?;
+ let piece1: usize = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let piece2: usize = parts
+ .next()
+ .ok_or(Error::BadSyntax)?
+ .parse()
+ .map_err(|_| Error::BadSyntax)?;
+ let mut error = None;
+ server
+ .connectivity
+ .fetch_and_update(puzzle_id, |curr_connectivity| {
+ let Some(curr_connectivity) = curr_connectivity else {
+ error = Some(Error::BadPuzzleID);
+ return None;
+ };
+ let mut new_connectivity = curr_connectivity.to_vec();
+ if piece1 >= curr_connectivity.len() / 2
+ || piece2 >= curr_connectivity.len() / 2
+ {
+ error = Some(Error::BadPieceID);
+ return Some(new_connectivity);
+ }
+ let piece2_group = u16::from_le_bytes([
+ curr_connectivity[piece2 * 2],
+ curr_connectivity[piece2 * 2 + 1],
]);
- if piece_group == piece2_group {
- new_connectivity[piece * 2] = a;
- new_connectivity[piece * 2 + 1] = b;
+ let a = curr_connectivity[piece1 * 2];
+ let b = curr_connectivity[piece1 * 2 + 1];
+ for piece in 0..curr_connectivity.len() / 2 {
+ let piece_group = u16::from_le_bytes([
+ curr_connectivity[piece * 2],
+ curr_connectivity[piece * 2 + 1],
+ ]);
+ if piece_group == piece2_group {
+ new_connectivity[piece * 2] = a;
+ new_connectivity[piece * 2 + 1] = b;
+ }
}
- }
- if server
- .connectivity
- .compare_and_swap(
- puzzle_id,
- Some(curr_connectivity),
- Some(new_connectivity),
- )?
- .is_ok()
- {
- break;
- }
- tokio::time::sleep(std::time::Duration::from_millis(1)).await; // yield maybe (don't let contention hog resources)
+ Some(new_connectivity)
+ })?;
+ if let Some(error) = error {
+ return Err(error);
}
} else if text == "poll" {
- let puzzle_id = puzzle_id.ok_or_else(|| anyhow!("poll without puzzle ID"))?;
- let pieces = server
- .pieces
- .get(puzzle_id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
+ let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?;
+ let pieces = server.pieces.get(puzzle_id)?.ok_or(Error::BadPuzzleID)?;
let connectivity = server
.connectivity
.get(puzzle_id)?
- .ok_or_else(|| anyhow!("bad puzzle ID"))?;
+ .ok_or(Error::BadPuzzleID)?;
let mut data = vec![2, 0, 0, 0, 0, 0, 0, 0]; // opcode / version number + padding
data.extend_from_slice(&pieces);
data.extend_from_slice(&connectivity);
@@ -285,18 +356,51 @@ async fn handle_connection(
Ok(())
}
+async fn handle_connection(server: &Server, conn: &mut tokio::net::TcpStream) -> Result<()> {
+ let mut puzzle_id = None;
+ let mut ws = tokio_tungstenite::accept_async_with_config(
+ conn,
+ Some(tungstenite::protocol::WebSocketConfig {
+ max_message_size: Some(65536),
+ max_frame_size: Some(65536),
+ ..Default::default()
+ }),
+ )
+ .await?;
+ let status = handle_websocket(server, &mut puzzle_id, &mut ws).await;
+ if let Err(e) = &status {
+ ws.send(Message::Text(format!("error {e}"))).await?;
+ };
+ if let Some(puzzle_id) = puzzle_id {
+ *server
+ .player_counts
+ .lock()
+ .await
+ .entry(puzzle_id)
+ .or_insert_with(|| {
+ eprintln!("negative player count??");
+ // prevent underflow
+ 1
+ }) -= 1;
+ }
+ status
+}
+
fn read_to_lines(path: &str) -> std::io::Result<Vec<String>> {
let file = std::fs::File::open(path)?;
let reader = std::io::BufReader::new(file);
reader.lines().collect()
}
-async fn try_get_potd() -> anyhow::Result<String> {
+async fn try_get_potd() -> Result<String> {
let output = tokio::process::Command::new("python3")
.arg("potd.py")
.output()
.await?;
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
+ Ok(String::from_utf8(output.stdout)
+ .map_err(|e| e.utf8_error())?
+ .trim()
+ .to_string())
}
async fn get_potd() -> String {
match try_get_potd().await {
@@ -320,8 +424,7 @@ async fn main() {
}
};
let start_time = SystemTime::now();
- // leak this since we need all threads to be able to access this
- let server: &'static Server = Box::leak(Box::new({
+ let server_arc: Arc<Server> = Arc::new({
let wikimedia_featured =
read_to_lines("featuredpictures.txt").expect("Couldn't read featuredpictures.txt");
let db = sled::open("database.sled").expect("error opening database");
@@ -334,12 +437,15 @@ async fn main() {
Server {
puzzles,
pieces,
+ player_counts: Mutex::new(HashMap::new()),
connectivity,
wikimedia_potd: RwLock::new(potd),
wikimedia_featured,
}
- }));
+ });
+ let server_arc_clone = server_arc.clone();
tokio::task::spawn(async move {
+ let server: &Server = server_arc_clone.as_ref();
fn next_day(t: SystemTime) -> SystemTime {
let day = 60 * 60 * 24;
let dt = t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
@@ -354,7 +460,9 @@ async fn main() {
last_time = SystemTime::now();
}
});
- tokio::task::spawn(async {
+ let server_arc_clone = server_arc.clone();
+ tokio::task::spawn(async move {
+ let server: &Server = server_arc_clone.as_ref();
loop {
// TODO : sweep
let now = SystemTime::now();
@@ -376,15 +484,18 @@ async fn main() {
server
.puzzles
.remove(&key)
- .expect("sweep failed to delete entry");
+ .expect("sweep failed to delete puzzle");
server
.pieces
.remove(&key)
- .expect("sweep failed to delete entry");
+ .expect("sweep failed to delete pieces");
server
.connectivity
.remove(&key)
- .expect("sweep failed to delete entry");
+ .expect("sweep failed to delete connectivity");
+ if let Some(key) = <[u8; PUZZLE_ID_LEN]>::try_from(&key[..]).ok() {
+ server.player_counts.lock().await.remove(&key);
+ }
}
tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
}
@@ -397,7 +508,9 @@ async fn main() {
continue;
}
};
+ let server_arc_clone = server_arc.clone();
tokio::task::spawn(async move {
+ let server: &Server = server_arc_clone.as_ref();
match handle_connection(server, &mut stream).await {
Ok(()) => {}
Err(e) => {