diff options
Diffstat (limited to 'server/src/main.rs')
-rw-r--r-- | server/src/main.rs | 99 |
1 files changed, 65 insertions, 34 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index 907dc3e..a6ffd2c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -7,7 +7,6 @@ use safe_transmute::{transmute_many_pedantic, transmute_to_bytes}; use std::collections::HashMap; use std::io::prelude::*; use std::net::SocketAddr; -use std::sync::Arc; use std::time::{Duration, SystemTime}; use tokio::io::AsyncWriteExt; use tokio::sync::{Mutex, RwLock}; @@ -30,6 +29,12 @@ struct Server { wikimedia_featured: Vec<String>, wikimedia_potd: RwLock<String>, database: tokio_postgres::Client, + create_puzzle: tokio_postgres::Statement, + set_puzzle_data: tokio_postgres::Statement, + move_piece: tokio_postgres::Statement, + connect_pieces: tokio_postgres::Statement, + get_piece_info: tokio_postgres::Statement, + get_puzzle_info: tokio_postgres::Statement, } struct PieceInfo { @@ -70,6 +75,9 @@ impl Server { &[], ) .await?; + self.database + .execute("CREATE INDEX by_id ON puzzles (id)", &[]) + .await?; } Ok(()) } @@ -77,7 +85,7 @@ impl Server { let id = std::str::from_utf8(&id)?; Ok(self .database - .execute("INSERT INTO puzzles (id) VALUES ($1)", &[&id]) + .execute(&self.create_puzzle, &[&id]) .await .is_ok()) } @@ -102,8 +110,7 @@ impl Server { let positions = &piece_positions; self.database .execute( - "UPDATE puzzles SET width = $1, height = $2, url = $3, nib_types = $4, - connectivity = $5, positions = $6 WHERE id = $7", + &self.set_puzzle_data, &[ &width, &height, @@ -132,11 +139,9 @@ impl Server { // NOTE: postgresql arrays start at index 1! let i0 = piece * 2 + 1; let i1 = piece * 2 + 2; - self.database.execute( - "UPDATE puzzles SET positions[$1] = $2, positions[$3] = $4 WHERE id = $5 AND $6 < width * height", - // the $6 < width * height protects against OOB access! - &[&i0, &x, &i1, &y, &id, &piece] - ).await?; + self.database + .execute(&self.move_piece, &[&i0, &x, &i1, &y, &id, &piece]) + .await?; Ok(()) } async fn connect_pieces( @@ -149,22 +154,14 @@ impl Server { // NOTE: postgresql arrays start at index 1! let piece1 = piece1 as i32 + 1; let piece2 = piece2 as i32 + 1; - self.database.execute( - "UPDATE puzzles SET connectivity = array_replace(connectivity, connectivity[$1], connectivity[$2]) WHERE id = $3 AND $4 < width * height AND $5 < width * height", - // the $6 < width * height protects against OOB access! - &[&piece1, &piece2, &id, &piece1, &piece2] - ).await?; + self.database + .execute(&self.connect_pieces, &[&piece1, &piece2, &id]) + .await?; Ok(()) } async fn get_piece_info(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<PieceInfo> { let id = std::str::from_utf8(&id)?; - let rows = self - .database - .query( - "SELECT positions, connectivity FROM puzzles WHERE id = $1", - &[&id], - ) - .await?; + let rows = self.database.query(&self.get_piece_info, &[&id]).await?; let row = &rows[0]; let positions: Vec<f32> = row.try_get(0)?; let connectivity: Vec<i16> = row.try_get(1)?; @@ -175,10 +172,7 @@ impl Server { } async fn get_puzzle_info(&self, id: [u8; PUZZLE_ID_LEN]) -> Result<PuzzleInfo> { let id = std::str::from_utf8(&id)?; - let rows = self.database.query( - "SELECT width, height, url, positions, nib_types, connectivity FROM puzzles WHERE id = $1", - &[&id] - ).await?; + let rows = self.database.query(&self.get_puzzle_info, &[&id]).await?; let row = &rows[0]; let width: i32 = row.try_get(0)?; let height: i32 = row.try_get(1)?; @@ -522,7 +516,7 @@ async fn main() { } }; let start_time = SystemTime::now(); - let server_arc: Arc<Server> = Arc::new({ + let server: &'static Server = Box::leak(Box::new({ let wikimedia_featured = read_to_lines("featuredpictures.txt").expect("Couldn't read featuredpictures.txt"); let potd = get_potd().await; @@ -539,20 +533,61 @@ async fn main() { eprintln!("connection error: {}", e); } }); + use tokio_postgres::types::Type; + let create_puzzle = client + .prepare_typed("INSERT INTO puzzles (id) VALUES ($1)", &[Type::BPCHAR]) + .await + .expect("couldn't prepare create_puzzle statement"); + let set_puzzle_data = client + .prepare_typed( + "UPDATE puzzles SET width = $1, height = $2, url = $3, nib_types = $4, + connectivity = $5, positions = $6 WHERE id = $7", + &[ + Type::INT4, + Type::INT4, + Type::TEXT, + Type::INT2_ARRAY, + Type::INT2_ARRAY, + Type::FLOAT4_ARRAY, + Type::BPCHAR, + ], + ) + .await + .expect("couldn't prepare set_puzzle_data statement"); + let move_piece = client.prepare_typed("UPDATE puzzles SET positions[$1] = $2, positions[$3] = $4 WHERE id = $5 AND $6 < width * height", + &[Type::INT4, Type::FLOAT4, Type::INT4, Type::FLOAT4, Type::BPCHAR, Type::INT4]) + .await.expect("couldn't prepare move_piece statement"); + let connect_pieces = client.prepare_typed( + "UPDATE puzzles SET connectivity = array_replace(connectivity, connectivity[$1], connectivity[$2]) WHERE id = $3 AND $1 < width * height AND $2 < width * height", + &[Type::INT4, Type::INT4, Type::BPCHAR]) + .await.expect("couldn't prepare connect_pieces statement"); + let get_piece_info = client + .prepare_typed( + "SELECT positions, connectivity FROM puzzles WHERE id = $1", + &[Type::BPCHAR], + ) + .await + .expect("couldn't prepare get_piece_info statement"); + let get_puzzle_info = client.prepare_typed("SELECT width, height, url, positions, nib_types, connectivity FROM puzzles WHERE id = $1", &[Type::BPCHAR]) + .await.expect("couldn't prepare get_puzzle_info statement"); Server { player_counts: Mutex::new(HashMap::new()), + create_puzzle, + set_puzzle_data, + move_piece, + connect_pieces, + get_piece_info, + get_puzzle_info, database: client, wikimedia_potd: RwLock::new(potd), wikimedia_featured, } - }); - server_arc + })); + server .create_table_if_not_exists() .await .expect("error creating table"); - let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { - let server: &Server = server_arc_clone.as_ref(); fn next_day(t: SystemTime) -> SystemTime { let day = 60 * 60 * 24; let dt = t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs(); @@ -567,9 +602,7 @@ async fn main() { last_time = SystemTime::now(); } }); - let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { - let server: &Server = server_arc_clone.as_ref(); loop { if let Err(e) = server.sweep().await { eprintln!("error sweeping DB: {e}"); @@ -586,9 +619,7 @@ async fn main() { continue; } }; - let server_arc_clone = server_arc.clone(); tokio::task::spawn(async move { - let server: &Server = server_arc_clone.as_ref(); match handle_connection(server, &mut stream).await { Ok(()) => {} Err(e) => { |