summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: pommicket <pommicket@gmail.com> 2023-09-05 14:37:16 -0400
committer: pommicket <pommicket@gmail.com> 2023-09-05 14:37:16 -0400
commit: 2fb5acc305d94a3827bd11eb8c72b68daa0a2a6a (patch)
tree: 10b1373ab7b1895c0a8ca6b1564760de164353f2
parent: 19f2fdc726c531d5bbc05fbd0ea1445f61208ffb (diff)
clean up code a bit
-rw-r--r-- README.md 2
-rw-r--r-- src/lib.rs 407
2 files changed, 223 insertions, 186 deletions
diff --git a/README.md b/README.md
index 5011829..f74b727 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ Also it has tiny code size (e.g. >8x smaller `.wasm.gz` size compared to the
## Goals
-- Correctly decode all valid non-interlaced PNG files (on 32-bit platforms, some very large images
+- Correctly decode all valid non-interlaced PNG files (on ≤32-bit platforms, some very large images
might fail because of `usize::MAX`).
- Small code size & complexity
- No dependencies other than `core`
diff --git a/src/lib.rs b/src/lib.rs
index eb41034..50547a8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -379,8 +379,7 @@ impl ImageHeader {
}
fn data_size(&self) -> usize {
- let row_bytes = self.bytes_per_row();
- row_bytes * self.height() as usize
+ self.bytes_per_row() * self.height() as usize
}
}
@@ -406,15 +405,19 @@ impl<'a, R: Read> From<IdatReader<'a, R>> for BitReader<'a, R> {
}
impl<R: Read> BitReader<'_, R> {
+ fn read_more_bits(&mut self) -> Result<(), Error<R::Error>> {
+ let mut new_bits = [0; ReadBits::BITS as usize / 8];
+ self.inner.read_partial(&mut new_bits)?;
+ let new_bits = Bits::from(ReadBits::from_le_bytes(new_bits));
+ self.bits |= new_bits << self.bits_left;
+ self.bits_left += ReadBits::BITS as u8;
+ Ok(())
+ }
+
fn peek_bits(&mut self, count: u8) -> Result<u32, Error<R::Error>> {
debug_assert!(count > 0 && u32::from(count) <= 31);
if self.bits_left < count {
- // read more bits
- let mut new_bits = [0; ReadBits::BITS as usize / 8];
- self.inner.read_partial(&mut new_bits)?;
- let new_bits = Bits::from(ReadBits::from_le_bytes(new_bits));
- self.bits |= new_bits << self.bits_left;
- self.bits_left += ReadBits::BITS as u8;
+ self.read_more_bits()?;
}
Ok((self.bits as u32) & ((1 << count) - 1))
}
@@ -515,7 +518,7 @@ const HUFFMAN_MAIN_TABLE_SIZE: usize = 1 << HUFFMAN_MAIN_TABLE_BITS;
/// which is just the encoded value and length.
/// for long codes, the look-up table returns a position in the tree
/// to start from.
-#[derive(Debug)]
+#[derive(Debug, Clone, Copy)]
struct HuffmanTable {
main_table: [i16; HUFFMAN_MAIN_TABLE_SIZE],
tree: [i16; HUFFMAN_MAX_CODES * 2 + 1],
@@ -593,17 +596,23 @@ impl HuffmanTable {
table
}
- fn read_value<R: Read>(&self, reader: &mut BitReader<'_, R>) -> Result<u16, Error<R::Error>> {
- let mut code = reader.peek_bits(HUFFMAN_MAX_BITS)? as u16;
- let mut entry = self.main_table[usize::from(code) & (HUFFMAN_MAIN_TABLE_SIZE - 1)];
- if entry < 0 {
- code >>= HUFFMAN_MAIN_TABLE_BITS;
- while entry < 0 {
- entry = self.tree[usize::from(code & 1) + (-entry) as usize];
- code >>= 1;
- }
+ fn lookup_slow(&self, mut entry: i16, mut code: u16) -> u16 {
+ code >>= HUFFMAN_MAIN_TABLE_BITS;
+ while entry < 0 {
+ entry = self.tree[usize::from(code & 1) + (-entry) as usize];
+ code >>= 1;
}
- let entry = entry as u16;
+ entry as u16
+ }
+
+ fn read_value<R: Read>(&self, reader: &mut BitReader<'_, R>) -> Result<u16, Error<R::Error>> {
+ let code = reader.peek_bits(HUFFMAN_MAX_BITS)? as u16;
+ let entry = self.main_table[usize::from(code) & (HUFFMAN_MAIN_TABLE_SIZE - 1)];
+ let entry = if entry > 0 {
+ entry as u16
+ } else {
+ self.lookup_slow(entry, code)
+ };
let length = (entry >> 9) as u8;
if length == 0 {
return Err(Error::BadCode);
@@ -899,98 +908,147 @@ pub fn read_png_header<R: Read>(reader: &mut R) -> Result<ImageHeader, Error<R::
Ok(hdr)
}
-fn read_compressed_block<R: Read>(
+fn read_dynamic_huffman_dictionary<R: Read>(
reader: &mut BitReader<'_, R>,
- writer: &mut DecompressedDataWriter,
- dynamic: bool,
-) -> Result<(), Error<R::Error>> {
- let literal_length_table;
- let distance_table;
-
- if dynamic {
- let literal_length_code_lengths_count = reader.read_bits_usize(5)? + 257;
- let distance_code_lengths_count = reader.read_bits_usize(5)? + 1;
- let code_length_code_lengths_count = reader.read_bits_usize(4)? + 4;
- let mut code_length_code_lengths = [0; 19];
- for i in 0..code_length_code_lengths_count {
- const ORDER: [u8; 19] = [
- 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
- ];
- code_length_code_lengths[usize::from(ORDER[i])] = reader.read_bits_u8(3)?;
- }
- let code_length_table = HuffmanTable::from_code_lengths(&code_length_code_lengths);
- let mut code_lengths = [0; 286 + 32];
- let mut i = 0;
- let total_code_lengths = literal_length_code_lengths_count + distance_code_lengths_count;
- loop {
- let op = code_length_table.read_value(reader)? as u8;
- if op < 16 {
- code_lengths[i] = op;
+) -> Result<(HuffmanTable, HuffmanTable), Error<R::Error>> {
+ let literal_length_code_lengths_count = reader.read_bits_usize(5)? + 257;
+ let distance_code_lengths_count = reader.read_bits_usize(5)? + 1;
+ let code_length_code_lengths_count = reader.read_bits_usize(4)? + 4;
+ let mut code_length_code_lengths = [0; 19];
+ for i in 0..code_length_code_lengths_count {
+ const ORDER: [u8; 19] = [
+ 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
+ ];
+ code_length_code_lengths[usize::from(ORDER[i])] = reader.read_bits_u8(3)?;
+ }
+ let code_length_table = HuffmanTable::from_code_lengths(&code_length_code_lengths);
+ let mut code_lengths = [0; 286 + 32];
+ let mut i = 0;
+ let total_code_lengths = literal_length_code_lengths_count + distance_code_lengths_count;
+ loop {
+ let op = code_length_table.read_value(reader)? as u8;
+ if op < 16 {
+ code_lengths[i] = op;
+ i += 1;
+ } else if op == 16 {
+ let rep = reader.read_bits_usize(2)? + 3;
+ if i == 0 || i + rep > total_code_lengths {
+ return Err(Error::BadHuffmanDict);
+ }
+ let l = code_lengths[i - 1];
+ for _ in 0..rep {
+ code_lengths[i] = l;
i += 1;
- } else if op == 16 {
- let rep = reader.read_bits_usize(2)? + 3;
- if i == 0 || i + rep > total_code_lengths {
- return Err(Error::BadHuffmanDict);
- }
- let l = code_lengths[i - 1];
- for _ in 0..rep {
- code_lengths[i] = l;
- i += 1;
- }
- } else if op == 17 {
- let rep = reader.read_bits_usize(3)? + 3;
- if i + rep > total_code_lengths {
- return Err(Error::BadHuffmanDict);
- }
- for _ in 0..rep {
- code_lengths[i] = 0;
- i += 1;
- }
- } else if op == 18 {
- let rep = reader.read_bits_usize(7)? + 11;
- if i + rep > total_code_lengths {
- return Err(Error::BadHuffmanDict);
- }
- for _ in 0..rep {
- code_lengths[i] = 0;
- i += 1;
- }
- } else {
- // since we only assigned 0..=18 in the huffman table,
- // we should never get a value outside that range.
- debug_assert!(false, "should not be reachable");
}
- if i >= total_code_lengths {
- break;
+ } else if op == 17 {
+ let rep = reader.read_bits_usize(3)? + 3;
+ if i + rep > total_code_lengths {
+ return Err(Error::BadHuffmanDict);
}
+ for _ in 0..rep {
+ code_lengths[i] = 0;
+ i += 1;
+ }
+ } else if op == 18 {
+ let rep = reader.read_bits_usize(7)? + 11;
+ if i + rep > total_code_lengths {
+ return Err(Error::BadHuffmanDict);
+ }
+ for _ in 0..rep {
+ code_lengths[i] = 0;
+ i += 1;
+ }
+ } else {
+ // since we only assigned 0..=18 in the huffman table,
+ // we should never get a value outside that range.
+ debug_assert!(false, "should not be reachable");
}
- let literal_length_code_lengths = &code_lengths[0..literal_length_code_lengths_count];
- let distance_code_lengths =
- &code_lengths[literal_length_code_lengths_count..total_code_lengths];
- literal_length_table = HuffmanTable::from_code_lengths(literal_length_code_lengths);
- distance_table = HuffmanTable::from_code_lengths(distance_code_lengths);
- } else {
- let mut lit = HuffmanTable::default();
- let mut dist = HuffmanTable::default();
- for i in 0..=143 {
- lit.assign(0b00110000 + i, 8, i);
- }
- for i in 144..=255 {
- lit.assign(0b110010000 + (i - 144), 9, i);
- }
- for i in 256..=279 {
- lit.assign(i - 256, 7, i);
- }
- for i in 280..=287 {
- lit.assign(0b11000000 + (i - 280), 8, i);
- }
- for i in 0..30 {
- dist.assign(i, 5, i);
+ if i >= total_code_lengths {
+ break;
}
+ }
+ let literal_length_code_lengths = &code_lengths[0..min(literal_length_code_lengths_count, 286)];
+ let distance_code_lengths = &code_lengths[literal_length_code_lengths_count
+ ..min(total_code_lengths, literal_length_code_lengths_count + 30)];
+ Ok((
+ HuffmanTable::from_code_lengths(literal_length_code_lengths),
+ HuffmanTable::from_code_lengths(distance_code_lengths),
+ ))
+}
- literal_length_table = lit;
- distance_table = dist;
+fn get_fixed_huffman_dictionaries() -> (HuffmanTable, HuffmanTable) {
+ let mut lit = HuffmanTable::default();
+ let mut dist = HuffmanTable::default();
+ for i in 0..=143 {
+ lit.assign(0b00110000 + i, 8, i);
+ }
+ for i in 144..=255 {
+ lit.assign(0b110010000 + (i - 144), 9, i);
+ }
+ for i in 256..=279 {
+ lit.assign(i - 256, 7, i);
}
+ for i in 280..=285 {
+ lit.assign(0b11000000 + (i - 280), 8, i);
+ }
+ for i in 0..30 {
+ dist.assign(i, 5, i);
+ }
+ (lit, dist)
+}
+
+fn read_compressed_block<R: Read>(
+ reader: &mut BitReader<'_, R>,
+ writer: &mut DecompressedDataWriter,
+ dynamic: bool,
+) -> Result<(), Error<R::Error>> {
+ let (literal_length_table, distance_table) = if dynamic {
+ read_dynamic_huffman_dictionary(reader)?
+ } else {
+ get_fixed_huffman_dictionaries()
+ };
+
+ fn parse_length<R: Read>(
+ reader: &mut BitReader<'_, R>,
+ literal_length: u16,
+ ) -> Result<u16, Error<R::Error>> {
+ Ok(match literal_length {
+ 257..=264 => literal_length - 254,
+ 265..=284 => {
+ const BASES: [u8; 20] = [
+ 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195,
+ 227,
+ ];
+ let base: u16 = BASES[usize::from(literal_length - 265)].into();
+ let extra_bits = (literal_length - 261) as u8 / 4;
+ let extra = reader.read_bits_u16(extra_bits)?;
+ base + extra
+ }
+ 285 => 258,
+ _ => unreachable!(), // we only could've assigned up to 285.
+ })
+ }
+
+ fn parse_distance<R: Read>(
+ reader: &mut BitReader<'_, R>,
+ distance_code: u16,
+ ) -> Result<u16, Error<R::Error>> {
+ Ok(match distance_code {
+ 0..=3 => distance_code + 1,
+ 4..=29 => {
+ const BASES: [u16; 26] = [
+ 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537,
+ 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
+ ];
+ let base = BASES[usize::from(distance_code - 4)];
+ let extra_bits = (distance_code - 2) as u8 / 2;
+ let extra = reader.read_bits_u16(extra_bits)?;
+ base + extra
+ }
+ _ => unreachable!(), // we only could've assigned up to 29.
+ })
+ }
+
loop {
let literal_length = literal_length_table.read_value(reader)?;
match literal_length {
@@ -998,50 +1056,42 @@ fn read_compressed_block<R: Read>(
// literal
writer.write_byte(literal_length as u8)?;
}
- 256 => {
- // end of block
- break;
- }
- _ => {
+ 257.. => {
// length + distance
- let length = match literal_length {
- 257..=264 => literal_length - 254,
- 265..=284 => {
- const BASES: [u8; 20] = [
- 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131,
- 163, 195, 227,
- ];
- let base: u16 = BASES[usize::from(literal_length - 265)].into();
- let extra_bits = (literal_length - 261) as u8 / 4;
- let extra = reader.read_bits_u16(extra_bits)?;
- base + extra
- }
- 285 => 258,
- _ => return Err(Error::BadCode),
- };
-
+ let length = parse_length(reader, literal_length)?;
let distance_code = distance_table.read_value(reader)?;
- let distance = match distance_code {
- 0..=3 => distance_code + 1,
- 4..=29 => {
- const BASES: [u16; 26] = [
- 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769,
- 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
- ];
- let base = BASES[usize::from(distance_code - 4)];
- let extra_bits = (distance_code - 2) as u8 / 2;
- let extra = reader.read_bits_u16(extra_bits)?;
- base + extra
- }
- _ => return Err(Error::BadCode),
- };
+ let distance = parse_distance(reader, distance_code)?;
writer.copy(usize::from(distance), usize::from(length))?;
}
+ 256 => {
+ // end of block
+ break;
+ }
}
}
Ok(())
}
+fn read_uncompressed_block<R: Read>(
+ reader: &mut BitReader<'_, R>,
+ writer: &mut DecompressedDataWriter,
+) -> Result<(), Error<R::Error>> {
+ reader.bits >>= reader.bits_left % 8;
+ reader.bits_left -= reader.bits_left % 8;
+ let len = reader.read_bits_u16(16)?;
+ let nlen = reader.read_bits_u16(16)?;
+ if len ^ nlen != 0xffff {
+ return Err(Error::BadNlen);
+ }
+ let len: usize = len.into();
+ if len > writer.slice.len() - writer.pos {
+ return Err(Error::TooMuchData);
+ }
+ reader.read_aligned_bytes(&mut writer.slice[writer.pos..writer.pos + len])?;
+ writer.pos += len;
+ Ok(())
+}
+
fn read_idat<R: Read>(
reader: IdatReader<'_, R>,
writer: &mut DecompressedDataWriter,
@@ -1059,35 +1109,28 @@ fn read_idat<R: Read>(
if compression_method != 8 || compression_info > 7 {
return Err(Error::BadZlibHeader);
}
+ // no preset dictionary
if (flags & 0x100) != 0 {
return Err(Error::BadZlibHeader);
}
let decompressed_size = reader.inner.header.decompressed_size();
- while writer.pos < decompressed_size {
+ loop {
let bfinal = reader.read_bits(1)?;
let btype = reader.read_bits(2)?;
- if btype == 0 {
- // uncompressed block
- reader.bits >>= reader.bits_left % 8;
- reader.bits_left -= reader.bits_left % 8;
- let len = reader.read_bits_u16(16)?;
- let nlen = reader.read_bits_u16(16)?;
- if len ^ nlen != 0xffff {
- return Err(Error::BadNlen);
+ match btype {
+ 0 => {
+ // uncompressed block
+ read_uncompressed_block(&mut reader, writer)?;
}
- let len: usize = len.into();
- if len > writer.slice.len() - writer.pos {
- return Err(Error::TooMuchData);
+ 1 | 2 => {
+ // compressed block
+ read_compressed_block(&mut reader, writer, btype == 2)?;
+ }
+ _ => {
+ // 0b11 is not a valid block type
+ return Err(Error::BadBlockType);
}
- reader.read_aligned_bytes(&mut writer.slice[writer.pos..writer.pos + len])?;
- writer.pos += len;
- } else if btype == 1 || btype == 2 {
- // compressed block
- read_compressed_block(&mut reader, writer, btype == 2)?;
- } else {
- // 0b11 is not a valid block type
- return Err(Error::BadBlockType);
}
if bfinal != 0 {
break;
@@ -1148,21 +1191,6 @@ fn apply_filters<I: IOError>(header: &ImageHeader, data: &mut [u8]) -> Result<()
const FILTER_AVG: u8 = 3;
const FILTER_PAETH: u8 = 4;
- #[inline]
- fn paeth(a: u8, b: u8, c: u8) -> u8 {
- let p = i32::from(a) + i32::from(b) - i32::from(c);
- let pa = (p - i32::from(a)).abs();
- let pb = (p - i32::from(b)).abs();
- let pc = (p - i32::from(c)).abs();
- if pa <= pb && pa <= pc {
- a
- } else if pb <= pc {
- b
- } else {
- c
- }
- }
-
s += 1;
data.copy_within(s..s + scanline_bytes, d);
match (filter, scanline == 0) {
@@ -1195,19 +1223,30 @@ fn apply_filters<I: IOError>(header: &ImageHeader, data: &mut [u8]) -> Result<()
}
(FILTER_PAETH, false) => {
for i in d..d + x_byte_offset {
- data[i] = data[i].wrapping_add(paeth(0, data[i - scanline_bytes], 0));
+ data[i] = data[i].wrapping_add(data[i - scanline_bytes]);
}
for i in d + x_byte_offset..d + scanline_bytes {
- data[i] = data[i].wrapping_add(paeth(
- data[i - x_byte_offset],
- data[i - scanline_bytes],
- data[i - scanline_bytes - x_byte_offset],
- ));
+ let a = data[i - x_byte_offset];
+ let b = data[i - scanline_bytes];
+ let c = data[i - scanline_bytes - x_byte_offset];
+
+ let p = i32::from(a) + i32::from(b) - i32::from(c);
+ let pa = (p - i32::from(a)).abs();
+ let pb = (p - i32::from(b)).abs();
+ let pc = (p - i32::from(c)).abs();
+ let paeth = if pa <= pb && pa <= pc {
+ a
+ } else if pb <= pc {
+ b
+ } else {
+ c
+ };
+ data[i] = data[i].wrapping_add(paeth);
}
}
(FILTER_PAETH, true) => {
for i in d + x_byte_offset..d + scanline_bytes {
- data[i] = data[i].wrapping_add(paeth(data[i - x_byte_offset], 0, 0));
+ data[i] = data[i].wrapping_add(data[i - x_byte_offset]);
}
}
(5.., _) => return Err(Error::BadFilter),
@@ -1254,8 +1293,7 @@ fn read_non_idat_chunks<R: Read>(
for i in 0..count {
palette[i][0..3].copy_from_slice(&data[3 * i..3 * i + 3]);
}
- // checksum
- reader.skip_bytes(4)?;
+ reader.skip_bytes(4)?; // CRC
} else if &chunk_type == b"tRNS" && header.color_type == ColorType::Indexed {
if chunk_len > 256 {
return Err(Error::BadTrnsChunk);
@@ -1265,8 +1303,7 @@ fn read_non_idat_chunks<R: Read>(
for i in 0..chunk_len {
palette[i][3] = data[i];
}
- // checksum
- reader.skip_bytes(4)?;
+ reader.skip_bytes(4)?; // CRC
} else if (chunk_type[0] & 0x20) != 0 || &chunk_type == b"PLTE" {
// non-essential chunk
reader.skip_bytes(chunk_len + 4)?;
@@ -1322,8 +1359,8 @@ pub fn read_png<'a, R: Read>(
palette[1] = [255, 255, 255, 255];
}
BitDepth::Two => {
- #[allow(clippy::needless_range_loop)]
// clippy's suggestion here is more unreadable imo
+ #[allow(clippy::needless_range_loop)]
for i in 0..4 {
let v = (255 * i / 3) as u8;
palette[i] = [v, v, v, 255];