diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 124 |
1 files changed, 108 insertions, 16 deletions
@@ -10,6 +10,7 @@ impl<T: Sized + Display + Debug> IOError for T {} #[non_exhaustive] pub enum Error<I: IOError> { IO(I), + BufferTooSmall, NotPng, BadIhdr, UnrecognizedChunk([u8; 4]), @@ -27,6 +28,7 @@ pub enum Error<I: IOError> { BadTrnsChunk, BadNlen, NoIdat, + BadAdlerChecksum, } impl<I: IOError> From<I> for Error<I> { @@ -41,6 +43,7 @@ impl<I: IOError> Display for Error<I> { Self::IO(e) => write!(f, "{e}"), Self::NotPng => write!(f, "not a png file"), Self::BadIhdr => write!(f, "bad IHDR chunk"), + Self::BufferTooSmall => write!(f, "provided buffer is too small"), Self::UnrecognizedChunk([a, b, c, d]) => { write!(f, "unrecognized chunk type: {a} {b} {c} {d}") } @@ -60,6 +63,7 @@ impl<I: IOError> Display for Error<I> { Self::BadTrnsChunk => write!(f, "bad tRNS chunk"), Self::NoIdat => write!(f, "missing IDAT chunk"), Self::BadNlen => write!(f, "LEN doesn't match NLEN"), + Self::BadAdlerChecksum => write!(f, "bad adler-32 checksum"), } } } @@ -131,6 +135,7 @@ struct IdatReader<'a, R: Read> { bytes_left_in_block: usize, palette: &'a mut [[u8; 4]; 256], header: &'a ImageHeader, + eof: bool, } impl<R: Read> IdatReader<'_, R> { @@ -147,8 +152,13 @@ impl<R: Read> IdatReader<'_, R> { // CRC self.inner.skip_bytes(4)?; + match read_non_idat_chunks(self.inner, self.header, self.palette)? { - None => Ok(bytes_read), + None => { + self.bytes_left_in_block = 0; + self.eof = true; + Ok(bytes_read) + } Some(n) => { self.bytes_left_in_block = n; Ok(self.read_partial(&mut buf[bytes_read..])? + bytes_read) @@ -167,15 +177,20 @@ impl<R: Read> IdatReader<'_, R> { } fn read_to_end(&mut self) -> Result<(), Error<R::Error>> { - self.inner.skip_bytes(self.bytes_left_in_block)?; - // CRC - self.inner.skip_bytes(4)?; - loop { - match read_non_idat_chunks(self.inner, self.header, self.palette)? { - None => break, - Some(n) => self.inner.skip_bytes(n + 4)?, + if !self.eof { + if self.bytes_left_in_block > 0 { + self.inner.skip_bytes(self.bytes_left_in_block)?; + } + // CRC + self.inner.skip_bytes(4)?; + loop { + match read_non_idat_chunks(self.inner, self.header, self.palette)? { + None => break, + Some(n) => self.inner.skip_bytes(n + 4)?, + } } } + self.eof = true; Ok(()) } } @@ -594,6 +609,8 @@ fn read_compressed_block<R: Read>( i += 1; } } else { + // since we only assigned 0..=18 in the huffman table, + // we should never get a value outside that range. debug_assert!(false, "should not be reachable"); } if i >= total_code_lengths { @@ -643,11 +660,11 @@ fn read_compressed_block<R: Read>( let length = match literal_length { 257..=264 => literal_length - 254, 265..=284 => { - const BASES: [u16; 20] = [ + const BASES: [u8; 20] = [ 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, ]; - let base = BASES[usize::from(literal_length - 265)]; + let base: u16 = BASES[usize::from(literal_length - 265)].into(); let extra_bits = (literal_length - 261) as u8 / 4; let extra = reader.read_bits_u16(extra_bits)?; base + extra @@ -683,7 +700,22 @@ fn read_idat<R: Read>( writer: &mut DecompressedDataWriter, ) -> Result<(), Error<R::Error>> { let mut reader = BitReader::from(reader); - let _zlib_header = reader.read_bits(16); + // zlib header + let cmf = reader.read_bits(8)?; + let flags = reader.read_bits(8)?; + // check zlib checksum + if (cmf * 256 + flags) % 31 != 0 { + return Err(Error::BadZlibHeader); + } + let compression_method = cmf & 0xf; + let compression_info = cmf >> 4; + if compression_method != 8 || compression_info > 7 { + return Err(Error::BadZlibHeader); + } + if (flags & 0x100) != 0 { + return Err(Error::BadZlibHeader); + } + let decompressed_size = reader.inner.header.decompressed_size(); while writer.pos < decompressed_size { let bfinal = reader.read_bits(1)?; @@ -714,6 +746,40 @@ fn read_idat<R: Read>( break; } } + + #[cfg(feature = "adler")] + { + // adler32 checksum + let padding = reader.bits_left % 8; + if padding > 0 { + reader.bits >>= padding; + reader.bits_left -= padding; + } + // NOTE: currently `read_bits` doesn't support reads of 32 bits. + let mut expected_adler = reader.read_bits(16)?; + expected_adler |= reader.read_bits(16)? << 16; + expected_adler = expected_adler.swap_bytes(); + + const BASE: u32 = 65521; + let mut s1: u32 = 1; + let mut s2: u32 = 0; + for byte in writer.slice[..decompressed_size].iter().copied() { + s1 += u32::from(byte); + if s1 > BASE { + s1 -= BASE; + } + s2 += s1; + if s2 > BASE { + s2 -= BASE; + } + } + let got_adler = s2 << 16 | s1; + if got_adler != expected_adler { + return Err(Error::BadAdlerChecksum); + } + } + + // padding bytes reader.inner.read_to_end()?; Ok(()) @@ -878,6 +944,9 @@ pub fn read_png<'a, R: Read>( None => read_png_header(reader)?, Some(h) => *h, }; + if buf.len() < header.required_bytes() { + return Err(Error::BufferTooSmall); + } let mut writer = DecompressedDataWriter::from(buf); let mut palette = [[0, 0, 0, 0]; 256]; let Some(idat_len) = read_non_idat_chunks(reader, &header, &mut palette)? else { @@ -889,6 +958,7 @@ pub fn read_png<'a, R: Read>( bytes_left_in_block: idat_len, header: &header, palette: &mut palette, + eof: false, }, &mut writer, )?; @@ -905,19 +975,22 @@ pub fn read_png<'a, R: Read>( #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "std")] use std::fs::File; + extern crate alloc; + #[cfg(feature = "std")] fn test_file(path: &str) { let decoder = png::Decoder::new(File::open(path).expect("file not found")); let mut reader = decoder.read_info().unwrap(); - let mut png_buf = vec![0; reader.output_buffer_size()]; + let mut png_buf = alloc::vec![0; reader.output_buffer_size()]; let png_header = reader.next_frame(&mut png_buf).unwrap(); let png_bytes = &png_buf[..png_header.buffer_size()]; let mut r = std::io::BufReader::new(File::open(path).expect("file not found")); let tiny_header = read_png_header(&mut r).unwrap(); - let mut tiny_buf = vec![0; tiny_header.required_bytes()]; + let mut tiny_buf = alloc::vec![0; tiny_header.required_bytes()]; let image = read_png(&mut r, Some(&tiny_header), &mut tiny_buf).unwrap(); let tiny_bytes = image.pixels(); @@ -929,12 +1002,12 @@ mod tests { let decoder = png::Decoder::new(bytes); let mut reader = decoder.read_info().unwrap(); - let mut png_buf = vec![0; reader.output_buffer_size()]; + let mut png_buf = alloc::vec![0; reader.output_buffer_size()]; let png_header = reader.next_frame(&mut png_buf).unwrap(); let png_bytes = &png_buf[..png_header.buffer_size()]; let tiny_header = read_png_header(&mut bytes).unwrap(); - let mut tiny_buf = vec![0; tiny_header.required_bytes()]; + let mut tiny_buf = alloc::vec![0; tiny_header.required_bytes()]; let image = read_png(&mut bytes, Some(&tiny_header), &mut tiny_buf).unwrap(); let tiny_bytes = image.pixels(); @@ -944,7 +1017,10 @@ mod tests { macro_rules! test_both { ($file:literal) => { - test_file($file); + #[cfg(feature = "std")] + { + test_file($file); + } test_bytes(include_bytes!(concat!("../", $file))); }; } @@ -1041,4 +1117,20 @@ mod tests { fn test_ouroboros() { test_both!("test/ouroboros.png"); } + #[test] + fn test_bad_png() { + let mut data = &b"hello"[..]; + // in this case we might actually get an unexpected EOF + assert!(read_png_header(&mut data).is_err()); + let mut data = &b"helloadfalskdfjlksajdflkjsadlkfj"[..]; + let err = read_png_header(&mut data).unwrap_err(); + assert!(matches!(err, Error::NotPng)); + } + #[test] + fn test_buffer_too_small() { + let mut data = &include_bytes!("../test/ouroboros.png")[..]; + let mut buffer = [0; 128]; + let err = read_png(&mut data, None, &mut buffer[..]).unwrap_err(); + assert!(matches!(err, Error::BufferTooSmall)); + } } |