From 2fb5acc305d94a3827bd11eb8c72b68daa0a2a6a Mon Sep 17 00:00:00 2001 From: pommicket Date: Tue, 5 Sep 2023 14:37:16 -0400 Subject: clean up code a bit --- src/lib.rs | 407 +++++++++++++++++++++++++++++++++---------------------------- 1 file changed, 222 insertions(+), 185 deletions(-) (limited to 'src') diff --git a/src/lib.rs b/src/lib.rs index eb41034..50547a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -379,8 +379,7 @@ impl ImageHeader { } fn data_size(&self) -> usize { - let row_bytes = self.bytes_per_row(); - row_bytes * self.height() as usize + self.bytes_per_row() * self.height() as usize } } @@ -406,15 +405,19 @@ impl<'a, R: Read> From> for BitReader<'a, R> { } impl BitReader<'_, R> { + fn read_more_bits(&mut self) -> Result<(), Error> { + let mut new_bits = [0; ReadBits::BITS as usize / 8]; + self.inner.read_partial(&mut new_bits)?; + let new_bits = Bits::from(ReadBits::from_le_bytes(new_bits)); + self.bits |= new_bits << self.bits_left; + self.bits_left += ReadBits::BITS as u8; + Ok(()) + } + fn peek_bits(&mut self, count: u8) -> Result> { debug_assert!(count > 0 && u32::from(count) <= 31); if self.bits_left < count { - // read more bits - let mut new_bits = [0; ReadBits::BITS as usize / 8]; - self.inner.read_partial(&mut new_bits)?; - let new_bits = Bits::from(ReadBits::from_le_bytes(new_bits)); - self.bits |= new_bits << self.bits_left; - self.bits_left += ReadBits::BITS as u8; + self.read_more_bits()?; } Ok((self.bits as u32) & ((1 << count) - 1)) } @@ -515,7 +518,7 @@ const HUFFMAN_MAIN_TABLE_SIZE: usize = 1 << HUFFMAN_MAIN_TABLE_BITS; /// which is just the encoded value and length. /// for long codes, the look-up table returns a position in the tree /// to start from. -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] struct HuffmanTable { main_table: [i16; HUFFMAN_MAIN_TABLE_SIZE], tree: [i16; HUFFMAN_MAX_CODES * 2 + 1], @@ -593,17 +596,23 @@ impl HuffmanTable { table } - fn read_value(&self, reader: &mut BitReader<'_, R>) -> Result> { - let mut code = reader.peek_bits(HUFFMAN_MAX_BITS)? as u16; - let mut entry = self.main_table[usize::from(code) & (HUFFMAN_MAIN_TABLE_SIZE - 1)]; - if entry < 0 { - code >>= HUFFMAN_MAIN_TABLE_BITS; - while entry < 0 { - entry = self.tree[usize::from(code & 1) + (-entry) as usize]; - code >>= 1; - } + fn lookup_slow(&self, mut entry: i16, mut code: u16) -> u16 { + code >>= HUFFMAN_MAIN_TABLE_BITS; + while entry < 0 { + entry = self.tree[usize::from(code & 1) + (-entry) as usize]; + code >>= 1; } - let entry = entry as u16; + entry as u16 + } + + fn read_value(&self, reader: &mut BitReader<'_, R>) -> Result> { + let code = reader.peek_bits(HUFFMAN_MAX_BITS)? as u16; + let entry = self.main_table[usize::from(code) & (HUFFMAN_MAIN_TABLE_SIZE - 1)]; + let entry = if entry > 0 { + entry as u16 + } else { + self.lookup_slow(entry, code) + }; let length = (entry >> 9) as u8; if length == 0 { return Err(Error::BadCode); @@ -899,98 +908,147 @@ pub fn read_png_header(reader: &mut R) -> Result( +fn read_dynamic_huffman_dictionary( reader: &mut BitReader<'_, R>, - writer: &mut DecompressedDataWriter, - dynamic: bool, -) -> Result<(), Error> { - let literal_length_table; - let distance_table; - - if dynamic { - let literal_length_code_lengths_count = reader.read_bits_usize(5)? + 257; - let distance_code_lengths_count = reader.read_bits_usize(5)? + 1; - let code_length_code_lengths_count = reader.read_bits_usize(4)? + 4; - let mut code_length_code_lengths = [0; 19]; - for i in 0..code_length_code_lengths_count { - const ORDER: [u8; 19] = [ - 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15, - ]; - code_length_code_lengths[usize::from(ORDER[i])] = reader.read_bits_u8(3)?; - } - let code_length_table = HuffmanTable::from_code_lengths(&code_length_code_lengths); - let mut code_lengths = [0; 286 + 32]; - let mut i = 0; - let total_code_lengths = literal_length_code_lengths_count + distance_code_lengths_count; - loop { - let op = code_length_table.read_value(reader)? as u8; - if op < 16 { - code_lengths[i] = op; +) -> Result<(HuffmanTable, HuffmanTable), Error> { + let literal_length_code_lengths_count = reader.read_bits_usize(5)? + 257; + let distance_code_lengths_count = reader.read_bits_usize(5)? + 1; + let code_length_code_lengths_count = reader.read_bits_usize(4)? + 4; + let mut code_length_code_lengths = [0; 19]; + for i in 0..code_length_code_lengths_count { + const ORDER: [u8; 19] = [ + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15, + ]; + code_length_code_lengths[usize::from(ORDER[i])] = reader.read_bits_u8(3)?; + } + let code_length_table = HuffmanTable::from_code_lengths(&code_length_code_lengths); + let mut code_lengths = [0; 286 + 32]; + let mut i = 0; + let total_code_lengths = literal_length_code_lengths_count + distance_code_lengths_count; + loop { + let op = code_length_table.read_value(reader)? as u8; + if op < 16 { + code_lengths[i] = op; + i += 1; + } else if op == 16 { + let rep = reader.read_bits_usize(2)? + 3; + if i == 0 || i + rep > total_code_lengths { + return Err(Error::BadHuffmanDict); + } + let l = code_lengths[i - 1]; + for _ in 0..rep { + code_lengths[i] = l; i += 1; - } else if op == 16 { - let rep = reader.read_bits_usize(2)? + 3; - if i == 0 || i + rep > total_code_lengths { - return Err(Error::BadHuffmanDict); - } - let l = code_lengths[i - 1]; - for _ in 0..rep { - code_lengths[i] = l; - i += 1; - } - } else if op == 17 { - let rep = reader.read_bits_usize(3)? + 3; - if i + rep > total_code_lengths { - return Err(Error::BadHuffmanDict); - } - for _ in 0..rep { - code_lengths[i] = 0; - i += 1; - } - } else if op == 18 { - let rep = reader.read_bits_usize(7)? + 11; - if i + rep > total_code_lengths { - return Err(Error::BadHuffmanDict); - } - for _ in 0..rep { - code_lengths[i] = 0; - i += 1; - } - } else { - // since we only assigned 0..=18 in the huffman table, - // we should never get a value outside that range. - debug_assert!(false, "should not be reachable"); } - if i >= total_code_lengths { - break; + } else if op == 17 { + let rep = reader.read_bits_usize(3)? + 3; + if i + rep > total_code_lengths { + return Err(Error::BadHuffmanDict); } + for _ in 0..rep { + code_lengths[i] = 0; + i += 1; + } + } else if op == 18 { + let rep = reader.read_bits_usize(7)? + 11; + if i + rep > total_code_lengths { + return Err(Error::BadHuffmanDict); + } + for _ in 0..rep { + code_lengths[i] = 0; + i += 1; + } + } else { + // since we only assigned 0..=18 in the huffman table, + // we should never get a value outside that range. + debug_assert!(false, "should not be reachable"); } - let literal_length_code_lengths = &code_lengths[0..literal_length_code_lengths_count]; - let distance_code_lengths = - &code_lengths[literal_length_code_lengths_count..total_code_lengths]; - literal_length_table = HuffmanTable::from_code_lengths(literal_length_code_lengths); - distance_table = HuffmanTable::from_code_lengths(distance_code_lengths); - } else { - let mut lit = HuffmanTable::default(); - let mut dist = HuffmanTable::default(); - for i in 0..=143 { - lit.assign(0b00110000 + i, 8, i); - } - for i in 144..=255 { - lit.assign(0b110010000 + (i - 144), 9, i); - } - for i in 256..=279 { - lit.assign(i - 256, 7, i); - } - for i in 280..=287 { - lit.assign(0b11000000 + (i - 280), 8, i); - } - for i in 0..30 { - dist.assign(i, 5, i); + if i >= total_code_lengths { + break; } + } + let literal_length_code_lengths = &code_lengths[0..min(literal_length_code_lengths_count, 286)]; + let distance_code_lengths = &code_lengths[literal_length_code_lengths_count + ..min(total_code_lengths, literal_length_code_lengths_count + 30)]; + Ok(( + HuffmanTable::from_code_lengths(literal_length_code_lengths), + HuffmanTable::from_code_lengths(distance_code_lengths), + )) +} - literal_length_table = lit; - distance_table = dist; +fn get_fixed_huffman_dictionaries() -> (HuffmanTable, HuffmanTable) { + let mut lit = HuffmanTable::default(); + let mut dist = HuffmanTable::default(); + for i in 0..=143 { + lit.assign(0b00110000 + i, 8, i); + } + for i in 144..=255 { + lit.assign(0b110010000 + (i - 144), 9, i); + } + for i in 256..=279 { + lit.assign(i - 256, 7, i); } + for i in 280..=285 { + lit.assign(0b11000000 + (i - 280), 8, i); + } + for i in 0..30 { + dist.assign(i, 5, i); + } + (lit, dist) +} + +fn read_compressed_block( + reader: &mut BitReader<'_, R>, + writer: &mut DecompressedDataWriter, + dynamic: bool, +) -> Result<(), Error> { + let (literal_length_table, distance_table) = if dynamic { + read_dynamic_huffman_dictionary(reader)? + } else { + get_fixed_huffman_dictionaries() + }; + + fn parse_length( + reader: &mut BitReader<'_, R>, + literal_length: u16, + ) -> Result> { + Ok(match literal_length { + 257..=264 => literal_length - 254, + 265..=284 => { + const BASES: [u8; 20] = [ + 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, + 227, + ]; + let base: u16 = BASES[usize::from(literal_length - 265)].into(); + let extra_bits = (literal_length - 261) as u8 / 4; + let extra = reader.read_bits_u16(extra_bits)?; + base + extra + } + 285 => 258, + _ => unreachable!(), // we only could've assigned up to 285. + }) + } + + fn parse_distance( + reader: &mut BitReader<'_, R>, + distance_code: u16, + ) -> Result> { + Ok(match distance_code { + 0..=3 => distance_code + 1, + 4..=29 => { + const BASES: [u16; 26] = [ + 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, + 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, + ]; + let base = BASES[usize::from(distance_code - 4)]; + let extra_bits = (distance_code - 2) as u8 / 2; + let extra = reader.read_bits_u16(extra_bits)?; + base + extra + } + _ => unreachable!(), // we only could've assigned up to 29. + }) + } + loop { let literal_length = literal_length_table.read_value(reader)?; match literal_length { @@ -998,50 +1056,42 @@ fn read_compressed_block( // literal writer.write_byte(literal_length as u8)?; } - 256 => { - // end of block - break; - } - _ => { + 257.. => { // length + distance - let length = match literal_length { - 257..=264 => literal_length - 254, - 265..=284 => { - const BASES: [u8; 20] = [ - 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, - 163, 195, 227, - ]; - let base: u16 = BASES[usize::from(literal_length - 265)].into(); - let extra_bits = (literal_length - 261) as u8 / 4; - let extra = reader.read_bits_u16(extra_bits)?; - base + extra - } - 285 => 258, - _ => return Err(Error::BadCode), - }; - + let length = parse_length(reader, literal_length)?; let distance_code = distance_table.read_value(reader)?; - let distance = match distance_code { - 0..=3 => distance_code + 1, - 4..=29 => { - const BASES: [u16; 26] = [ - 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, - 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, - ]; - let base = BASES[usize::from(distance_code - 4)]; - let extra_bits = (distance_code - 2) as u8 / 2; - let extra = reader.read_bits_u16(extra_bits)?; - base + extra - } - _ => return Err(Error::BadCode), - }; + let distance = parse_distance(reader, distance_code)?; writer.copy(usize::from(distance), usize::from(length))?; } + 256 => { + // end of block + break; + } } } Ok(()) } +fn read_uncompressed_block( + reader: &mut BitReader<'_, R>, + writer: &mut DecompressedDataWriter, +) -> Result<(), Error> { + reader.bits >>= reader.bits_left % 8; + reader.bits_left -= reader.bits_left % 8; + let len = reader.read_bits_u16(16)?; + let nlen = reader.read_bits_u16(16)?; + if len ^ nlen != 0xffff { + return Err(Error::BadNlen); + } + let len: usize = len.into(); + if len > writer.slice.len() - writer.pos { + return Err(Error::TooMuchData); + } + reader.read_aligned_bytes(&mut writer.slice[writer.pos..writer.pos + len])?; + writer.pos += len; + Ok(()) +} + fn read_idat( reader: IdatReader<'_, R>, writer: &mut DecompressedDataWriter, @@ -1059,35 +1109,28 @@ fn read_idat( if compression_method != 8 || compression_info > 7 { return Err(Error::BadZlibHeader); } + // no preset dictionary if (flags & 0x100) != 0 { return Err(Error::BadZlibHeader); } let decompressed_size = reader.inner.header.decompressed_size(); - while writer.pos < decompressed_size { + loop { let bfinal = reader.read_bits(1)?; let btype = reader.read_bits(2)?; - if btype == 0 { - // uncompressed block - reader.bits >>= reader.bits_left % 8; - reader.bits_left -= reader.bits_left % 8; - let len = reader.read_bits_u16(16)?; - let nlen = reader.read_bits_u16(16)?; - if len ^ nlen != 0xffff { - return Err(Error::BadNlen); + match btype { + 0 => { + // uncompressed block + read_uncompressed_block(&mut reader, writer)?; } - let len: usize = len.into(); - if len > writer.slice.len() - writer.pos { - return Err(Error::TooMuchData); + 1 | 2 => { + // compressed block + read_compressed_block(&mut reader, writer, btype == 2)?; + } + _ => { + // 0b11 is not a valid block type + return Err(Error::BadBlockType); } - reader.read_aligned_bytes(&mut writer.slice[writer.pos..writer.pos + len])?; - writer.pos += len; - } else if btype == 1 || btype == 2 { - // compressed block - read_compressed_block(&mut reader, writer, btype == 2)?; - } else { - // 0b11 is not a valid block type - return Err(Error::BadBlockType); } if bfinal != 0 { break; @@ -1148,21 +1191,6 @@ fn apply_filters(header: &ImageHeader, data: &mut [u8]) -> Result<() const FILTER_AVG: u8 = 3; const FILTER_PAETH: u8 = 4; - #[inline] - fn paeth(a: u8, b: u8, c: u8) -> u8 { - let p = i32::from(a) + i32::from(b) - i32::from(c); - let pa = (p - i32::from(a)).abs(); - let pb = (p - i32::from(b)).abs(); - let pc = (p - i32::from(c)).abs(); - if pa <= pb && pa <= pc { - a - } else if pb <= pc { - b - } else { - c - } - } - s += 1; data.copy_within(s..s + scanline_bytes, d); match (filter, scanline == 0) { @@ -1195,19 +1223,30 @@ fn apply_filters(header: &ImageHeader, data: &mut [u8]) -> Result<() } (FILTER_PAETH, false) => { for i in d..d + x_byte_offset { - data[i] = data[i].wrapping_add(paeth(0, data[i - scanline_bytes], 0)); + data[i] = data[i].wrapping_add(data[i - scanline_bytes]); } for i in d + x_byte_offset..d + scanline_bytes { - data[i] = data[i].wrapping_add(paeth( - data[i - x_byte_offset], - data[i - scanline_bytes], - data[i - scanline_bytes - x_byte_offset], - )); + let a = data[i - x_byte_offset]; + let b = data[i - scanline_bytes]; + let c = data[i - scanline_bytes - x_byte_offset]; + + let p = i32::from(a) + i32::from(b) - i32::from(c); + let pa = (p - i32::from(a)).abs(); + let pb = (p - i32::from(b)).abs(); + let pc = (p - i32::from(c)).abs(); + let paeth = if pa <= pb && pa <= pc { + a + } else if pb <= pc { + b + } else { + c + }; + data[i] = data[i].wrapping_add(paeth); } } (FILTER_PAETH, true) => { for i in d + x_byte_offset..d + scanline_bytes { - data[i] = data[i].wrapping_add(paeth(data[i - x_byte_offset], 0, 0)); + data[i] = data[i].wrapping_add(data[i - x_byte_offset]); } } (5.., _) => return Err(Error::BadFilter), @@ -1254,8 +1293,7 @@ fn read_non_idat_chunks( for i in 0..count { palette[i][0..3].copy_from_slice(&data[3 * i..3 * i + 3]); } - // checksum - reader.skip_bytes(4)?; + reader.skip_bytes(4)?; // CRC } else if &chunk_type == b"tRNS" && header.color_type == ColorType::Indexed { if chunk_len > 256 { return Err(Error::BadTrnsChunk); @@ -1265,8 +1303,7 @@ fn read_non_idat_chunks( for i in 0..chunk_len { palette[i][3] = data[i]; } - // checksum - reader.skip_bytes(4)?; + reader.skip_bytes(4)?; // CRC } else if (chunk_type[0] & 0x20) != 0 || &chunk_type == b"PLTE" { // non-essential chunk reader.skip_bytes(chunk_len + 4)?; @@ -1322,8 +1359,8 @@ pub fn read_png<'a, R: Read>( palette[1] = [255, 255, 255, 255]; } BitDepth::Two => { - #[allow(clippy::needless_range_loop)] // clippy's suggestion here is more unreadable imo + #[allow(clippy::needless_range_loop)] for i in 0..4 { let v = (255 * i / 3) as u8; palette[i] = [v, v, v, 255]; -- cgit v1.2.3