summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authorpommicket <pommicket@gmail.com>2025-08-29 04:25:57 -0400
committerpommicket <pommicket@gmail.com>2025-08-29 04:25:57 -0400
commitaf16284921b65f92601279001531862d1c80cee7 (patch)
treeffabe5598ab96839de4ff44c53fa2ef44f4db062 /src/main.rs
Initial commit (v. 0.1.0)v0.1.0
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs222
1 files changed, 222 insertions, 0 deletions
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..ffa1892
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,222 @@
+pub mod argparse;
+
+use std::collections::HashMap;
+use std::io::{self, prelude::*};
+use std::net::{IpAddr, TcpListener, TcpStream};
+use std::process::ExitCode;
+use std::sync::{
+ Arc,
+ atomic::{AtomicBool, Ordering},
+};
+use std::thread::sleep;
+use std::time::Duration;
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+enum Flag {
+ Help,
+ Version,
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+enum Param {
+ Port,
+ Address,
+ Timeout,
+ MaxConnections,
+}
+
+fn print_version() {
+ println!(
+ "{} v. {}",
+ env!("CARGO_PKG_NAME"),
+ env!("CARGO_PKG_VERSION")
+ );
+}
+
+fn print_doc() {
+ print_version();
+ println!("{}", include_str!("../README.txt"));
+}
+
+#[derive(Debug)]
+struct ConnectionSettings {
+ timeout: u32,
+ buffer: [u8; 1024],
+}
+
+const SLEEP_MS_PER_ROUND: u32 = 50;
+
+impl ConnectionSettings {
+ /// handle this connection as much as possible *without blocking*
+ #[must_use]
+ fn handle_connection(&mut self, addr: IpAddr, conn: &mut Connection) -> bool {
+ let buffer = &mut self.buffer;
+ conn.age_ms += SLEEP_MS_PER_ROUND;
+ if conn.age_ms > self.timeout * 1000 {
+ return false;
+ }
+ if conn.state < 4 {
+ let n = match conn.stream.read(buffer) {
+ Ok(n) => n,
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => 0,
+ Err(e) => {
+ eprintln!("WARNING: error reading from connection: {e}");
+ return false;
+ }
+ };
+ for &c in &buffer[..n] {
+ if c != b'\r' && c != b'\n' {
+ continue;
+ }
+ let expected = if conn.state % 2 == 0 { b'\r' } else { b'\n' };
+ if c == expected {
+ conn.state += 1;
+ if conn.state == 4 {
+ break;
+ }
+ }
+ }
+ }
+ if conn.state == 4 {
+ // 90B is more than long enough for our response
+ write!(
+ &mut buffer[..90],
+ "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n{addr}\n\0"
+ )
+ .expect("?? writing address to buffer failed??");
+ let length = buffer[..90]
+ .iter()
+ .position(|&c| c == 0)
+ .expect("?? no null byte even though we just wrote one??");
+ match conn
+ .stream
+ .write(&buffer[usize::from(conn.written)..length])
+ {
+ Ok(n) => {
+ conn.written += n as u8;
+ if conn.written >= length as u8 {
+ // hooray! done sending response
+ return false;
+ }
+ }
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ // keep connection around
+ }
+ Err(e) => {
+ eprintln!("WARNING: error writing to connection: {e}");
+ return false;
+ }
+ }
+ }
+ true
+ }
+}
+
+#[derive(Debug)]
+struct Connection {
+ stream: TcpStream,
+ /// rough time connection has been around in milliseconds
+ age_ms: u32,
+ /// number of characters out of \r\n\r\n (end of HTTP headers) that we have
+ state: u8,
+ /// number of characters of address which have been written
+ written: u8,
+}
+
+fn try_main() -> Result<(), Box<dyn std::error::Error>> {
+ let args = argparse::parse_args(
+ &HashMap::from_iter([("--help", Flag::Help), ("--version", Flag::Version)]),
+ &HashMap::from([
+ ("--port", Param::Port),
+ ("--address", Param::Address),
+ ("--timeout", Param::Timeout),
+ ("--max-connections", Param::MaxConnections),
+ ]),
+ )?;
+ if args.is_set(Flag::Help) {
+ print_doc();
+ return Ok(());
+ }
+ if args.is_set(Flag::Version) {
+ print_version();
+ return Ok(());
+ }
+ let addr = args.get(Param::Address).unwrap_or("0.0.0.0");
+ let port = args.get(Param::Port).unwrap_or("80");
+ let port = port
+ .parse::<u16>()
+ .map_err(|_| format!("Invalid port: {port}"))?;
+ let timeout = args.get(Param::Timeout).unwrap_or("15");
+ let timeout = timeout
+ .parse::<u32>()
+ .map_err(|_| format!("Invalid timeout: {timeout}"))?;
+ let max_connections = args.get(Param::MaxConnections).unwrap_or("32");
+ let max_connections = max_connections
+ .parse::<u32>()
+ .map_err(|_| format!("Invalid max connections: {max_connections}"))?;
+
+ let listener = TcpListener::bind((addr, port))
+ .map_err(|e| format!("Couldn't bind on {addr}:{port}: {e}"))?;
+
+ let interrupted = Arc::new(AtomicBool::new(false));
+ let handler_interrupted = interrupted.clone();
+ if let Err(e) = ctrlc::set_handler(move || {
+ handler_interrupted.store(true, Ordering::Relaxed);
+ }) {
+ eprintln!("Warning: Couldn't set SIGINT/TERM handler: {e}");
+ }
+ listener
+ .set_nonblocking(true)
+ .map_err(|e| format!("Couldn't set socket to non-blocking: {e}"))?;
+
+ let was_interrupted = || interrupted.load(Ordering::Relaxed);
+ let mut connections: HashMap<IpAddr, Connection> = HashMap::new();
+ let mut settings = ConnectionSettings {
+ timeout,
+ buffer: [0; 1024],
+ };
+
+ 'outer: while !was_interrupted() {
+ sleep(Duration::from_millis(SLEEP_MS_PER_ROUND.into())); // don't busy loop
+ while (connections.len() as u32) < max_connections {
+ match listener.accept() {
+ Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
+ if was_interrupted() {
+ break 'outer;
+ }
+ break;
+ }
+ Err(e) => eprintln!("Warning: accept() failed: {e}"),
+ Ok((stream, source_addr)) => {
+ if let Err(e) = stream.set_nonblocking(true) {
+ eprintln!(
+ "WARNING: dropping connection because set_nonblocking failed: {e}"
+ );
+ continue;
+ }
+ // if there was another connection from this address, it gets
+ // unceremoniously dropped. too bad.
+ connections.insert(
+ source_addr.ip(),
+ Connection {
+ stream,
+ age_ms: 0,
+ written: 0,
+ state: Default::default(),
+ },
+ );
+ }
+ }
+ }
+ connections.retain(|addr, conn| settings.handle_connection(*addr, conn));
+ }
+ Ok(())
+}
+
+fn main() -> ExitCode {
+ if let Err(e) = try_main() {
+ eprintln!("Error: {e}");
+ return ExitCode::FAILURE;
+ }
+ ExitCode::SUCCESS
+}