From af16284921b65f92601279001531862d1c80cee7 Mon Sep 17 00:00:00 2001 From: pommicket Date: Fri, 29 Aug 2025 04:25:57 -0400 Subject: Initial commit (v. 0.1.0) --- src/argparse.rs | 82 +++++++++++++++++++++ src/main.rs | 222 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 304 insertions(+) create mode 100644 src/argparse.rs create mode 100644 src/main.rs (limited to 'src') diff --git a/src/argparse.rs b/src/argparse.rs new file mode 100644 index 0000000..8fd6ae5 --- /dev/null +++ b/src/argparse.rs @@ -0,0 +1,82 @@ +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Debug)] +pub struct Args { + flags: HashSet, + params: HashMap, + lone_args: Vec, +} + +impl Args { + pub fn is_set(&self, flag: Flag) -> bool { + self.flags.contains(&flag) + } + pub fn get(&self, param: Param) -> Option<&str> { + Some(self.params.get(¶m)?.as_ref()) + } + pub fn lone_args(&self) -> impl '_ + Iterator { + self.lone_args.iter().map(|x| x.as_ref()) + } + pub fn lone_args_count(&self) -> usize { + self.lone_args.len() + } +} + +pub fn parse_args( + flag_names: &HashMap<&str, Flag>, + param_names: &HashMap<&str, Param>, +) -> Result, String> { + let mut arg_iter = std::env::args_os(); + arg_iter.next(); // program name + let mut args = Args { + flags: HashSet::new(), + params: HashMap::new(), + lone_args: vec![], + }; + let mut double_dash = false; + let mut param: Option<(String, Param)> = None; + assert!(flag_names.keys().all(|x| x.starts_with('-'))); + assert!(param_names.keys().all(|x| x.starts_with('-'))); + for arg in arg_iter { + let arg = arg + .into_string() + .map_err(|arg| format!("Argument includes bad UTF-8: {arg:?}"))?; + if let Some((_name, p)) = param.as_ref() { + args.params.insert(*p, arg); + param = None; + continue; + } + if double_dash { + args.lone_args.push(arg); + continue; + } + if arg == "--" { + double_dash = true; + continue; + } + if let Some(flag) = flag_names.get(arg.as_str()) { + args.flags.insert(*flag); + continue; + } + if let Some(p) = param_names.get(arg.as_str()) { + param = Some((arg, *p)); + continue; + } + if let Some((p, value)) = arg.split_once('=') + && let Some(p) = param_names.get(p) + { + args.params.insert(*p, value.into()); + continue; + } + if arg.starts_with('-') { + return Err(format!("Unrecognized option: {arg}")); + } + args.lone_args.push(arg); + } + if let Some((name, _p)) = param { + return Err(format!("No argument provided to {name}")); + } + Ok(args) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..ffa1892 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,222 @@ +pub mod argparse; + +use std::collections::HashMap; +use std::io::{self, prelude::*}; +use std::net::{IpAddr, TcpListener, TcpStream}; +use std::process::ExitCode; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; +use std::thread::sleep; +use std::time::Duration; + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +enum Flag { + Help, + Version, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +enum Param { + Port, + Address, + Timeout, + MaxConnections, +} + +fn print_version() { + println!( + "{} v. {}", + env!("CARGO_PKG_NAME"), + env!("CARGO_PKG_VERSION") + ); +} + +fn print_doc() { + print_version(); + println!("{}", include_str!("../README.txt")); +} + +#[derive(Debug)] +struct ConnectionSettings { + timeout: u32, + buffer: [u8; 1024], +} + +const SLEEP_MS_PER_ROUND: u32 = 50; + +impl ConnectionSettings { + /// handle this connection as much as possible *without blocking* + #[must_use] + fn handle_connection(&mut self, addr: IpAddr, conn: &mut Connection) -> bool { + let buffer = &mut self.buffer; + conn.age_ms += SLEEP_MS_PER_ROUND; + if conn.age_ms > self.timeout * 1000 { + return false; + } + if conn.state < 4 { + let n = match conn.stream.read(buffer) { + Ok(n) => n, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => 0, + Err(e) => { + eprintln!("WARNING: error reading from connection: {e}"); + return false; + } + }; + for &c in &buffer[..n] { + if c != b'\r' && c != b'\n' { + continue; + } + let expected = if conn.state % 2 == 0 { b'\r' } else { b'\n' }; + if c == expected { + conn.state += 1; + if conn.state == 4 { + break; + } + } + } + } + if conn.state == 4 { + // 90B is more than long enough for our response + write!( + &mut buffer[..90], + "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n{addr}\n\0" + ) + .expect("?? writing address to buffer failed??"); + let length = buffer[..90] + .iter() + .position(|&c| c == 0) + .expect("?? no null byte even though we just wrote one??"); + match conn + .stream + .write(&buffer[usize::from(conn.written)..length]) + { + Ok(n) => { + conn.written += n as u8; + if conn.written >= length as u8 { + // hooray! done sending response + return false; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + // keep connection around + } + Err(e) => { + eprintln!("WARNING: error writing to connection: {e}"); + return false; + } + } + } + true + } +} + +#[derive(Debug)] +struct Connection { + stream: TcpStream, + /// rough time connection has been around in milliseconds + age_ms: u32, + /// number of characters out of \r\n\r\n (end of HTTP headers) that we have + state: u8, + /// number of characters of address which have been written + written: u8, +} + +fn try_main() -> Result<(), Box> { + let args = argparse::parse_args( + &HashMap::from_iter([("--help", Flag::Help), ("--version", Flag::Version)]), + &HashMap::from([ + ("--port", Param::Port), + ("--address", Param::Address), + ("--timeout", Param::Timeout), + ("--max-connections", Param::MaxConnections), + ]), + )?; + if args.is_set(Flag::Help) { + print_doc(); + return Ok(()); + } + if args.is_set(Flag::Version) { + print_version(); + return Ok(()); + } + let addr = args.get(Param::Address).unwrap_or("0.0.0.0"); + let port = args.get(Param::Port).unwrap_or("80"); + let port = port + .parse::() + .map_err(|_| format!("Invalid port: {port}"))?; + let timeout = args.get(Param::Timeout).unwrap_or("15"); + let timeout = timeout + .parse::() + .map_err(|_| format!("Invalid timeout: {timeout}"))?; + let max_connections = args.get(Param::MaxConnections).unwrap_or("32"); + let max_connections = max_connections + .parse::() + .map_err(|_| format!("Invalid max connections: {max_connections}"))?; + + let listener = TcpListener::bind((addr, port)) + .map_err(|e| format!("Couldn't bind on {addr}:{port}: {e}"))?; + + let interrupted = Arc::new(AtomicBool::new(false)); + let handler_interrupted = interrupted.clone(); + if let Err(e) = ctrlc::set_handler(move || { + handler_interrupted.store(true, Ordering::Relaxed); + }) { + eprintln!("Warning: Couldn't set SIGINT/TERM handler: {e}"); + } + listener + .set_nonblocking(true) + .map_err(|e| format!("Couldn't set socket to non-blocking: {e}"))?; + + let was_interrupted = || interrupted.load(Ordering::Relaxed); + let mut connections: HashMap = HashMap::new(); + let mut settings = ConnectionSettings { + timeout, + buffer: [0; 1024], + }; + + 'outer: while !was_interrupted() { + sleep(Duration::from_millis(SLEEP_MS_PER_ROUND.into())); // don't busy loop + while (connections.len() as u32) < max_connections { + match listener.accept() { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + if was_interrupted() { + break 'outer; + } + break; + } + Err(e) => eprintln!("Warning: accept() failed: {e}"), + Ok((stream, source_addr)) => { + if let Err(e) = stream.set_nonblocking(true) { + eprintln!( + "WARNING: dropping connection because set_nonblocking failed: {e}" + ); + continue; + } + // if there was another connection from this address, it gets + // unceremoniously dropped. too bad. + connections.insert( + source_addr.ip(), + Connection { + stream, + age_ms: 0, + written: 0, + state: Default::default(), + }, + ); + } + } + } + connections.retain(|addr, conn| settings.handle_connection(*addr, conn)); + } + Ok(()) +} + +fn main() -> ExitCode { + if let Err(e) = try_main() { + eprintln!("Error: {e}"); + return ExitCode::FAILURE; + } + ExitCode::SUCCESS +} -- cgit v1.2.3