From bc34326b935d8d460c3a14951237a744d12d7de3 Mon Sep 17 00:00:00 2001 From: pommicket Date: Tue, 13 Aug 2024 00:08:45 -0400 Subject: start migration to postgres --- server/src/main.rs | 281 +++++++++++++++++------------------------------------ 1 file changed, 91 insertions(+), 190 deletions(-) (limited to 'server/src/main.rs') diff --git a/server/src/main.rs b/server/src/main.rs index f695910..eb9674b 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,3 +1,6 @@ +#![allow(dead_code)] // TODO : delete me +#![allow(unused_variables)] // TODO : delete me + use futures_util::{SinkExt, StreamExt}; use rand::seq::SliceRandom; use rand::Rng; @@ -9,6 +12,7 @@ use std::time::{Duration, SystemTime}; use tokio::io::AsyncWriteExt; use tokio::sync::{Mutex, RwLock}; use tungstenite::protocol::Message; +use zerocopy::AsBytes; const PUZZLE_ID_CHARSET: &[u8] = b"23456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ"; const PUZZLE_ID_LEN: usize = 7; @@ -21,19 +25,48 @@ fn generate_puzzle_id() -> [u8; PUZZLE_ID_LEN] { #[derive(Debug)] struct Server { - puzzles: sled::Tree, - pieces: sled::Tree, - connectivity: sled::Tree, // keep this in memory, since we want to reset it to 0 when the server restarts player_counts: Mutex>, wikimedia_featured: Vec, wikimedia_potd: RwLock, + database: tokio_postgres::Client, +} + + +impl Server { + async fn create_table_if_not_exists(&self) -> Result<()> { + todo!() + } + async fn try_register_id(&self, id: [u8; PUZZLE_ID_LEN]) -> Result { + todo!() + } + async fn set_puzzle_data(&self, id: [u8; PUZZLE_ID_LEN], width: u8, height: u8, url: &str, nib_types: Vec, piece_positions: Vec, connectivity_data: Vec) -> Result<()> { + todo!() + } + async fn move_piece(&self, piece: usize, x: f32, y: f32) -> Result<()> { + todo!() + } + async fn connect_pieces(&self, piece1: usize, piece2: usize) -> Result<()> { + todo!() + } + async fn get_connectivity(&self, id: [u8; PUZZLE_ID_LEN]) -> Result> { + todo!() + } + async fn get_positions(&self, id: [u8; PUZZLE_ID_LEN]) -> Result> { + todo!() + } + async fn get_details(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<(u8, u8, String)> { + todo!() + } + async fn sweep(&self) -> Result<()> { + todo!() + } } #[derive(Debug)] enum Error { Tungstenite(tungstenite::Error), - Sled(sled::Error), + Postgres(tokio_postgres::Error), IO(std::io::Error), UTF8(std::str::Utf8Error), BadPuzzleID, @@ -55,7 +88,7 @@ impl std::fmt::Display for Error { Error::TooManyPieces => write!(f, "too many pieces"), Error::NotJoined => write!(f, "haven't joined a puzzle"), Error::TooManyPlayers => write!(f, "too many players"), - Error::Sled(e) => write!(f, "{e}"), + Error::Postgres(e) => write!(f, "{e}"), Error::IO(e) => write!(f, "{e}"), Error::UTF8(e) => write!(f, "{e}"), Error::Tungstenite(e) => write!(f, "{e}"), @@ -63,9 +96,9 @@ impl std::fmt::Display for Error { } } -impl From for Error { - fn from(value: sled::Error) -> Self { - Self::Sled(value) +impl From for Error { + fn from(value: tokio_postgres::Error) -> Self { + Self::Postgres(value) } } @@ -87,21 +120,21 @@ impl From for Error { type Result = std::result::Result; -fn get_puzzle_info(server: &Server, id: &[u8]) -> Result> { - if id.len() != PUZZLE_ID_LEN { - return Err(Error::BadPuzzleID); - } - let mut data = vec![1, 0, 0, 0, 0, 0, 0, 0]; // opcode + padding - let puzzle = server.puzzles.get(id)?.ok_or(Error::BadPuzzleID)?; - data.extend_from_slice(&puzzle); +async fn get_puzzle_info(server: &Server, id: &[u8]) -> Result> { + let id: [u8; PUZZLE_ID_LEN] = id.try_into().map_err(|_| Error::BadPuzzleID)?; + let mut data = vec![1]; + let (width, height, url) = server.get_details(id).await?; + data.push(width); + data.push(height); + data.extend(url.as_bytes()); while data.len() % 8 != 0 { // padding data.push(0); } - let pieces = server.pieces.get(id)?.ok_or(Error::BadPuzzleID)?; - data.extend_from_slice(&pieces); - let connectivity = server.connectivity.get(id)?.ok_or(Error::BadPuzzleID)?; - data.extend_from_slice(&connectivity); + let pieces = server.get_positions(id).await?; + data.extend_from_slice(pieces.as_bytes()); + let connectivity = server.get_connectivity(id).await?; + data.extend_from_slice(connectivity.as_bytes()); Ok(data) } @@ -136,82 +169,46 @@ async fn handle_websocket( if (width as u16) * (height as u16) > 1000 { return Err(Error::TooManyPieces); } - let mut puzzle_data = vec![width, height]; - let timestamp: u64 = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .expect("time went backwards :/") - .as_secs(); - for byte in timestamp.to_le_bytes() { - puzzle_data.push(byte); - } - // pick nib types + let nib_count = 2 * (width as usize) * (height as usize) - (width as usize) - (height as usize); + let mut nib_types: Vec = Vec::with_capacity(nib_count); + let mut piece_positions: Vec = Vec::with_capacity((width as usize) * (height as usize) * 2); { let mut rng = rand::thread_rng(); - for _ in 0..2u16 * (width as u16) * (height as u16) - - (width as u16) - (height as u16) - { - puzzle_data.push(rng.gen()); - puzzle_data.push(rng.gen()); + // pick nib types + for _ in 0..nib_count { + nib_types.push(rng.gen()); } - } - // URL - puzzle_data.extend(url.as_bytes()); - puzzle_data.push(0); - // puzzle ID - let mut id; - loop { - id = generate_puzzle_id(); - let data = std::mem::take(&mut puzzle_data); - if server - .puzzles - .compare_and_swap(id, None::<&'static [u8; 0]>, Some(&data[..]))? - .is_ok() - { - break; - } - } - drop(puzzle_data); // should be empty now - *puzzle_id = Some(id); - let pieces_data: Box<[u8]>; - { - let mut rng = rand::thread_rng(); - let mut positions = vec![]; - positions.reserve_exact((width as usize) * (height as usize)); - // positions + // pick piece positions for y in 0..(height as u16) { for x in 0..(width as u16) { let dx: f32 = rng.gen_range(0.0..0.5); let dy: f32 = rng.gen_range(0.0..0.5); - positions.push([ - (x as f32 + dx) / ((width + 1) as f32), - (y as f32 + dy) / ((height + 1) as f32), - ]); + piece_positions.push((x as f32 + dx) / ((width + 1) as f32)); + piece_positions.push((y as f32 + dy) / ((height + 1) as f32)); } } - positions.shuffle(&mut rng); - // rust isn't smart enough to do the zero-copy with f32::to_le_bytes and Vec::into_flattened - let ptr: *mut [[f32; 2]] = Box::into_raw(positions.into_boxed_slice()); - let ptr: *mut [u8] = std::ptr::slice_from_raw_parts_mut( - ptr.cast(), - (width as usize) * (height as usize) * 8, - ); - // evil unsafe code >:3 - pieces_data = unsafe { Box::from_raw(ptr) }; + piece_positions.shuffle(&mut rng); } - server.pieces.insert(id, pieces_data)?; - let mut connectivity_data = - Vec::with_capacity((width as usize) * (height as usize) * 2); + let mut connectivity_data: Vec = + Vec::with_capacity((width as usize) * (height as usize)); for i in 0..(width as u16) * (height as u16) { - connectivity_data.extend(i.to_le_bytes()); + connectivity_data.push(i); } - server.connectivity.insert(id, connectivity_data)?; + let mut id; + loop { + id = generate_puzzle_id(); + if server.try_register_id(id).await? { + break; + } + } + server.set_puzzle_data(id, width, height, url, nib_types, piece_positions, connectivity_data).await?; server.player_counts.lock().await.insert(id, 1); ws.send(Message::Text(format!( "id: {}", std::str::from_utf8(&id).expect("puzzle ID has bad utf-8???") ))) .await?; - let info = get_puzzle_info(server, &id)?; + let info = get_puzzle_info(server, &id).await?; ws.send(Message::Binary(info)).await?; } else if let Some(id) = text.strip_prefix("join ") { let id = id.as_bytes().try_into().map_err(|_| Error::BadSyntax)?; @@ -223,17 +220,10 @@ async fn handle_websocket( *entry += 1; drop(player_counts); // release lock *puzzle_id = Some(id); - let info = get_puzzle_info(server, &id)?; + let info = get_puzzle_info(server, &id).await?; ws.send(Message::Binary(info)).await?; } else if text.starts_with("move ") { let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; - #[derive(Clone, Copy)] - struct Motion { - piece: usize, - x: f32, - y: f32, - } - let mut motions = vec![]; for line in text.split('\n') { let mut parts = line.split(' '); parts.next(); // skip "move" @@ -252,29 +242,7 @@ async fn handle_websocket( .ok_or(Error::BadSyntax)? .parse() .map_err(|_| Error::BadSyntax)?; - motions.push(Motion { piece, x, y }); - } - let mut error = None; - server - .pieces - .fetch_and_update(puzzle_id, |curr_pieces: Option<&[u8]>| { - let Some(curr_pieces) = curr_pieces else { - error = Some(Error::BadPuzzleID); - return None; - }; - let mut new_pieces = curr_pieces.to_vec(); - for Motion { piece, x, y } in motions.iter().copied() { - let Some(slice) = new_pieces.get_mut(8 * piece..8 * piece + 8) else { - error = Some(Error::BadPieceID); - break; - }; - slice[0..4].copy_from_slice(&x.to_le_bytes()); - slice[4..8].copy_from_slice(&y.to_le_bytes()); - } - Some(new_pieces) - })?; - if let Some(error) = error { - return Err(error); + server.move_piece(piece, x, y).await?; } ws.send(Message::Text("ack".to_string())).await?; } else if let Some(data) = text.strip_prefix("connect ") { @@ -290,52 +258,12 @@ async fn handle_websocket( .ok_or(Error::BadSyntax)? .parse() .map_err(|_| Error::BadSyntax)?; - let mut error = None; - server - .connectivity - .fetch_and_update(puzzle_id, |curr_connectivity| { - let Some(curr_connectivity) = curr_connectivity else { - error = Some(Error::BadPuzzleID); - return None; - }; - let mut new_connectivity = curr_connectivity.to_vec(); - if piece1 >= curr_connectivity.len() / 2 - || piece2 >= curr_connectivity.len() / 2 - { - error = Some(Error::BadPieceID); - return Some(new_connectivity); - } - let piece2_group = u16::from_le_bytes([ - curr_connectivity[piece2 * 2], - curr_connectivity[piece2 * 2 + 1], - ]); - let a = curr_connectivity[piece1 * 2]; - let b = curr_connectivity[piece1 * 2 + 1]; - for piece in 0..curr_connectivity.len() / 2 { - let piece_group = u16::from_le_bytes([ - curr_connectivity[piece * 2], - curr_connectivity[piece * 2 + 1], - ]); - if piece_group == piece2_group { - new_connectivity[piece * 2] = a; - new_connectivity[piece * 2 + 1] = b; - } - } - Some(new_connectivity) - })?; - if let Some(error) = error { - return Err(error); - } + server.connect_pieces(piece1, piece2).await?; } else if text == "poll" { let puzzle_id = puzzle_id.ok_or(Error::NotJoined)?; - let pieces = server.pieces.get(puzzle_id)?.ok_or(Error::BadPuzzleID)?; - let connectivity = server - .connectivity - .get(puzzle_id)? - .ok_or(Error::BadPuzzleID)?; let mut data = vec![2, 0, 0, 0, 0, 0, 0, 0]; // opcode / version number + padding - data.extend_from_slice(&pieces); - data.extend_from_slice(&connectivity); + data.extend_from_slice(server.get_positions(puzzle_id).await?.as_bytes()); + data.extend_from_slice(server.get_connectivity(puzzle_id).await?.as_bytes()); ws.send(Message::Binary(data)).await?; } else if text == "randomFeaturedWikimedia" { let choice = rand::thread_rng().gen_range(0..server.wikimedia_featured.len()); @@ -427,22 +355,23 @@ async fn main() { let server_arc: Arc = Arc::new({ let wikimedia_featured = read_to_lines("featuredpictures.txt").expect("Couldn't read featuredpictures.txt"); - let db = sled::open("database.sled").expect("error opening database"); - let puzzles = db.open_tree("PUZZLES").expect("error opening puzzles tree"); - let pieces = db.open_tree("PIECES").expect("error opening pieces tree"); - let connectivity = db - .open_tree("CONNECTIVITY") - .expect("error opening connectivity tree"); let potd = get_potd().await; + let (client, connection) = tokio_postgres::connect("host=/var/run/postgresql dbname=jigsaw", tokio_postgres::NoTls).await.expect("Couldn't connect to database"); + + // docs say: "The connection object performs the actual communication with the database, so spawn it off to run on its own." + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("connection error: {}", e); + } + }); Server { - puzzles, - pieces, player_counts: Mutex::new(HashMap::new()), - connectivity, + database: client, wikimedia_potd: RwLock::new(potd), wikimedia_featured, } }); + server_arc.create_table_if_not_exists().await.expect("error creating table"); let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { let server: &Server = server_arc_clone.as_ref(); @@ -466,36 +395,8 @@ async fn main() { loop { // TODO : sweep let now = SystemTime::now(); - let mut to_delete = vec![]; - for item in server.puzzles.iter() { - let (key, value) = item.expect("sweep failed to read database"); - let timestamp: [u8; 8] = value[2..2 + 8].try_into().unwrap(); - let timestamp = - SystemTime::UNIX_EPOCH + Duration::from_secs(u64::from_le_bytes(timestamp)); - if now.duration_since(timestamp).unwrap_or_default() - >= Duration::from_secs(60 * 60 * 24 * 7) - { - // delete puzzles created at least 1 week ago - to_delete.push(key); - } - } - for key in to_delete { - // technically there is a race condition here but stop being silly - server - .puzzles - .remove(&key) - .expect("sweep failed to delete puzzle"); - server - .pieces - .remove(&key) - .expect("sweep failed to delete pieces"); - server - .connectivity - .remove(&key) - .expect("sweep failed to delete connectivity"); - if let Some(key) = <[u8; PUZZLE_ID_LEN]>::try_from(&key[..]).ok() { - server.player_counts.lock().await.remove(&key); - } + if let Err(e) = server.sweep().await { + eprintln!("error sweeping DB: {e}"); } tokio::time::sleep(std::time::Duration::from_secs(3600)).await; } -- cgit v1.2.3