From 04d3faed5edf3bec45ad6ca21a12b6ae446cc878 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1niel=20Buga?= Date: Mon, 18 Aug 2025 20:16:25 +0200 Subject: [PATCH] Rewrite SHA finish to be nonblocking (#3948) --- esp-hal/CHANGELOG.md | 1 + esp-hal/src/reg_access.rs | 3 +- esp-hal/src/sha.rs | 218 ++++++++++++++++++++++++-------------- hil-test/tests/sha.rs | 2 +- 4 files changed, 143 insertions(+), 81 deletions(-) diff --git a/esp-hal/CHANGELOG.md b/esp-hal/CHANGELOG.md index 441430f8c..f13e27efd 100644 --- a/esp-hal/CHANGELOG.md +++ b/esp-hal/CHANGELOG.md @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DmaTxBuffer::from_view` and `DmaRxBuffer::from_view` now return an object with type `DmaTx/RxBuffer::Final`. (#3923) - `i2c::master::Config::timeout` has been de-stabilized, and `i2c::master::Config::software_timeout`. (#3926) - The default values of `i2c::master::Config` timeouts have been changed to their maximum possible values. (#3926) +- `ShaDigest::finish` has been reimplemented to be properly non-blocking (#3948) ### Fixed diff --git a/esp-hal/src/reg_access.rs b/esp-hal/src/reg_access.rs index 3a7d4a7d5..1240ec9b3 100644 --- a/esp-hal/src/reg_access.rs +++ b/esp-hal/src/reg_access.rs @@ -101,6 +101,7 @@ impl AlignmentHelper { .write_volatile(E::u32_from_bytes(self.buf)); } + // We return the **extra** bytes appended besides those already written into the buffer. let ret = U32_ALIGN_SIZE - self.buf_fill; self.buf_fill = 0; @@ -113,7 +114,7 @@ impl AlignmentHelper { // This function is similar to `volatile_set_memory` but will prepend data that // was previously ingested and ensure aligned (u32) writes. 
pub fn volatile_write(&mut self, dst_ptr: *mut u32, val: u8, count: usize, offset: usize) { - let count = count / U32_ALIGN_SIZE; + let count = count.div_ceil(U32_ALIGN_SIZE); let offset = offset / U32_ALIGN_SIZE; let dst_ptr = unsafe { dst_ptr.add(offset) }; diff --git a/esp-hal/src/sha.rs b/esp-hal/src/sha.rs index 4b2d4b073..020bd85dd 100644 --- a/esp-hal/src/sha.rs +++ b/esp-hal/src/sha.rs @@ -143,105 +143,141 @@ impl<'d> Sha<'d> { mod_cursor, ); - state.cursor = state.cursor.wrapping_add(incoming.len() - remaining.len()); + state.cursor += incoming.len() - remaining.len(); if bound_reached { - // Message is full now. - - if self.is_busy(state.algorithm) { - // The message buffer is full and the hardware is still processing the previous - // message. There's nothing to be done besides wait for the hardware. - state.message_buffer_is_full = true; - } else { - // Send the full buffer. - self.process_buffer(state); - } + // Message is full now. We don't have to wait for the result, just start the processing + // or set the flag. + _ = self.process_buffer_or_wait(state); } Ok(remaining) } + fn process_buffer_or_wait(&self, state: &mut DigestState) -> nb::Result<(), Infallible> { + if self.is_busy(state.algorithm) { + // The message buffer is full and the hardware is still processing the + // previous message. There's nothing to be done besides wait for the + // hardware. + state.message_buffer_is_full = true; + return Err(nb::Error::WouldBlock); + } + + // Send the full buffer. 
+ self.process_buffer(state); + + Ok(()) + } + fn finish(&self, state: &mut DigestState, output: &mut [u8]) -> nb::Result<(), Infallible> { - // Store message length for padding - let length = (state.cursor as u64 * 8).to_be_bytes(); - nb::block!(self.update(state, &[0x80]))?; // Append "1" bit - - let chunk_len = state.algorithm.chunk_length(); - - // Flush partial data, ensures aligned cursor - { - while self.is_busy(state.algorithm) {} - if state.message_buffer_is_full { - self.process_buffer(state); - - state.message_buffer_is_full = false; - while self.is_busy(state.algorithm) {} + if state.message_buffer_is_full { + // Wait for the hardware to become idle. + if self.is_busy(state.algorithm) { + return Err(nb::Error::WouldBlock); } + // Start processing so that we can continue writing into SHA memory. + self.process_buffer(state); + state.message_buffer_is_full = false; + } + + let chunk_len = state.algorithm.chunk_length(); + if state.finalize_state == FinalizeState::NotStarted { + let cursor = state.cursor; + self.update(state, &[0x80])?; // Append "1" bit + state.finished_message_size = cursor; + + state.finalize_state = FinalizeState::FlushAlignBuffer; + } + + if state.finalize_state == FinalizeState::FlushAlignBuffer { let flushed = state .alignment_helper .flush_to(m_mem(&self.sha, 0), state.cursor % chunk_len); - state.cursor = state.cursor.wrapping_add(flushed); - if flushed > 0 && state.cursor.is_multiple_of(chunk_len) { - self.process_buffer(state); - while self.is_busy(state.algorithm) {} + state.finalize_state = FinalizeState::ZeroPadAlmostFull; + if flushed > 0 { + state.cursor += flushed; + if state.cursor.is_multiple_of(chunk_len) { + self.process_buffer_or_wait(state)?; + } } } - debug_assert!(state.cursor.is_multiple_of(4)); let mut mod_cursor = state.cursor % chunk_len; - if (chunk_len - mod_cursor) < chunk_len / 8 { + if state.finalize_state == FinalizeState::ZeroPadAlmostFull { // Zero out remaining data if buffer is almost full 
(>=448/896), and process - // buffer + // buffer. + // + // In either case, we'll continue to the next state. + state.finalize_state = FinalizeState::WriteMessageLength; let pad_len = chunk_len - mod_cursor; + if pad_len < state.algorithm.message_length_bytes() { + state.alignment_helper.volatile_write( + m_mem(&self.sha, 0), + 0_u8, + pad_len, + mod_cursor, + ); + state.cursor += pad_len; + + self.process_buffer_or_wait(state)?; + mod_cursor = 0; + } + } + + if state.finalize_state == FinalizeState::WriteMessageLength { + // In this state, we pad the remainder of the message block with 0s and append the + // message length to the very end. + // FIXME: this u64 should be u128 for 1024-bit block algos. Since cursor is only usize + // (u32), this makes no difference currently, but may limit maximum message length in + // the future. + let message_len_bytes = size_of::<u64>(); + + let pad_len = chunk_len - mod_cursor - message_len_bytes; + // Fill remaining space with zeros state .alignment_helper - .volatile_write(m_mem(&self.sha, 0), 0_u8, pad_len, mod_cursor); - self.process_buffer(state); - state.cursor = state.cursor.wrapping_add(pad_len); + .volatile_write(m_mem(&self.sha, 0), 0, pad_len, mod_cursor); - debug_assert_eq!(state.cursor % chunk_len, 0); - mod_cursor = 0; + // Write message length + let length = state.finished_message_size as u64 * 8; + state.alignment_helper.aligned_volatile_copy( + m_mem(&self.sha, 0), + &length.to_be_bytes(), + chunk_len, + chunk_len - message_len_bytes, + ); - // Spin-wait for finish - while self.is_busy(state.algorithm) {} + // Set up last state, start processing + state.finalize_state = FinalizeState::ReadResult; + self.process_buffer_or_wait(state)?; } - let pad_len = chunk_len - mod_cursor - size_of::<u64>(); + if state.finalize_state == FinalizeState::ReadResult { + if state.algorithm.is_busy(&self.sha) { + return Err(nb::Error::WouldBlock); + } + if state.algorithm.load(&self.sha) { + // Spin wait for result, 8-20 clock cycles 
according to manual + while self.is_busy(state.algorithm) {} + } - state - .alignment_helper - .volatile_write(m_mem(&self.sha, 0), 0, pad_len, mod_cursor); + state.alignment_helper.volatile_read_regset( + h_mem(&self.sha, 0), + output, + core::cmp::min(output.len(), 32), + ); - state.alignment_helper.aligned_volatile_copy( - m_mem(&self.sha, 0), - &length, - chunk_len, - chunk_len - size_of::<u64>(), - ); + state.first_run = true; + state.cursor = 0; + state.alignment_helper.reset(); + state.finalize_state = FinalizeState::NotStarted; - self.process_buffer(state); - // Spin-wait for final buffer to be processed - while self.is_busy(state.algorithm) {} - - if state.algorithm.load(&self.sha) { - // Spin wait for result, 8-20 clock cycles according to manual - while self.is_busy(state.algorithm) {} + return Ok(()); } - state.alignment_helper.volatile_read_regset( - h_mem(&self.sha, 0), - output, - core::cmp::min(output.len(), 32), - ); - - state.first_run = true; - state.finished = true; - state.cursor = 0; - state.alignment_helper.reset(); - - Ok(()) + Err(nb::Error::WouldBlock) } fn update<'a>( @@ -249,8 +285,7 @@ impl<'d> Sha<'d> { state: &mut DigestState, incoming: &'a [u8], ) -> nb::Result<&'a [u8], Infallible> { - state.finished = false; - + state.finalize_state = FinalizeState::default(); self.write_data(state, incoming) } } @@ -283,14 +318,25 @@ pub struct ShaDigest<'d, A, S: Borrow<Sha<'d>>> { phantom: PhantomData<(&'d (), A)>, } +#[derive(Clone, Copy, Debug, PartialEq, Default)] +enum FinalizeState { + #[default] + NotStarted, + FlushAlignBuffer, + ZeroPadAlmostFull, + WriteMessageLength, + ReadResult, +} + #[derive(Clone, Debug)] struct DigestState { algorithm: ShaAlgorithmKind, alignment_helper: AlignmentHelper, cursor: usize, first_run: bool, - finished: bool, + finished_message_size: usize, message_buffer_is_full: bool, + finalize_state: FinalizeState, } impl DigestState { @@ -300,8 +346,9 @@ impl DigestState { alignment_helper: AlignmentHelper::default(), cursor: 
0, first_run: true, - finished: false, + finished_message_size: 0, message_buffer_is_full: false, + finalize_state: FinalizeState::default(), } } } @@ -440,13 +487,6 @@ impl Context { pub fn first_run(&self) -> bool { self.state.first_run } - - /// Indicates if the SHA context has finished processing the data. - /// - /// Returns `true` if the SHA calculation is complete, otherwise returns. - pub fn finished(&self) -> bool { - self.state.finished - } } #[cfg(not(esp32))] @@ -620,6 +660,26 @@ impl ShaAlgorithmKind { } } + /// Bytes needed to represent the length of the longest possible message. + const fn message_length_bytes(self) -> usize { + match self { + #[cfg(sha_algo_sha_1)] + ShaAlgorithmKind::Sha1 => 8, + #[cfg(sha_algo_sha_224)] + ShaAlgorithmKind::Sha224 => 8, + #[cfg(sha_algo_sha_256)] + ShaAlgorithmKind::Sha256 => 8, + #[cfg(sha_algo_sha_384)] + ShaAlgorithmKind::Sha384 => 16, + #[cfg(sha_algo_sha_512)] + ShaAlgorithmKind::Sha512 => 16, + #[cfg(sha_algo_sha_512_224)] + ShaAlgorithmKind::Sha512_224 => 16, + #[cfg(sha_algo_sha_512_256)] + ShaAlgorithmKind::Sha512_256 => 16, + } + } + fn start(self, sha: &crate::peripherals::SHA<'_>) { let regs = sha.register_block(); cfg_if::cfg_if! { diff --git a/hil-test/tests/sha.rs b/hil-test/tests/sha.rs index ce91cc9cd..528905848 100644 --- a/hil-test/tests/sha.rs +++ b/hil-test/tests/sha.rs @@ -32,7 +32,7 @@ fn assert_sw_hash(input: &[u8], expected_output: &[u8]) { hasher.update(input); let soft_result = hasher.finalize(); - defmt::assert_eq!(expected_output, &soft_result[..]); + hil_test::assert_eq!(expected_output, &soft_result[..]); } fn hash_sha(sha: &mut Sha<'static>, mut input: &[u8], output: &mut [u8]) {