Test and fix async RSA (#2002)

* RSA cleanup & API consistency change, part 2

* RSA cleanup & API consistency change, part 3

* Add async tests

* Fix async for ESP32

* Merge impl blocks

* Backtrack on some mutability changes

* Use Acquire/Release ordering

* Fwd to write_multi_start instead of duplicating impl
This commit is contained in:
Dániel Buga 2024-08-28 11:56:05 +02:00 committed by GitHub
parent c7a7760b51
commit 6abbc72e11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 425 additions and 512 deletions

View File

@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove `fn free(self)` in HMAC which goes against esp-hal API guidelines (#1972)
- PARL_IO use ReadBuffer and WriteBuffer for Async DMA (#1996)
- `AnyPin`, `AnyInputOnyPin` and `DummyPin` are now accessible from `gpio` module (#1918)
- Changed the RSA modular multiplication API to be consistent across devices (#2002)
### Fixed
@ -45,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Reset peripherals in driver constructors where missing (#1893, #1961)
- Fixed ESP32-S2 systimer interrupts (#1979)
- Software interrupt 3 is no longer available when it is required by `esp-hal-embassy`. (#2011)
- ESP32: Fixed async RSA (#2002)
### Removed

View File

@ -1,8 +1,4 @@
use core::{
convert::Infallible,
marker::PhantomData,
ptr::{copy_nonoverlapping, write_bytes},
};
use core::convert::Infallible;
use crate::rsa::{
implement_op,
@ -37,35 +33,30 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
/// Starts the modular exponentiation operation.
pub(super) fn write_modexp_start(&mut self) {
pub(super) fn write_modexp_start(&self) {
self.rsa
.modexp_start()
.write(|w| w.modexp_start().set_bit());
}
/// Starts the multiplication operation.
pub(super) fn write_multi_start(&mut self) {
pub(super) fn write_multi_start(&self) {
self.rsa.mult_start().write(|w| w.mult_start().set_bit());
}
/// Starts the modular multiplication operation.
pub(super) fn write_modmulti_start(&self) {
self.write_multi_start();
}
/// Clears the RSA interrupt flag.
pub(super) fn clear_interrupt(&mut self) {
self.rsa.interrupt().write(|w| w.interrupt().set_bit());
}
/// Checks if the RSA peripheral is idle.
pub(super) fn is_idle(&mut self) -> bool {
self.rsa.interrupt().read().bits() == 1
}
unsafe fn write_multi_operand_a<const N: usize>(&mut self, operand_a: &[u32; N]) {
copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N);
write_bytes(self.rsa.x_mem(0).as_ptr().add(N), 0, N);
}
unsafe fn write_multi_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
write_bytes(self.rsa.z_mem(0).as_ptr(), 0, N);
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N);
pub(super) fn is_idle(&self) -> bool {
self.rsa.interrupt().read().interrupt().bit_is_set()
}
}
@ -92,59 +83,18 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaMultiplication`.
///
/// `m_prime` can be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 24.3.2 of <https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf>.
pub fn new(rsa: &'a mut Rsa<'d, DM>, modulus: &T::InputType, m_prime: u32) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_modulus(modulus);
}
rsa.write_mprime(m_prime);
Self {
rsa,
phantom: PhantomData,
}
}
fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_multi_mode((N / 16 - 1) as u32)
}
/// Starts the first step of modular multiplication operation.
///
/// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
/// Starts the modular multiplication operation.
///
/// For more information refer to 24.3.2 of <https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf>.
pub fn start_step1(&mut self, operand_a: &T::InputType, r: &T::InputType) {
unsafe {
self.rsa.write_operand_a(operand_a);
self.rsa.write_r(r);
}
self.start();
}
/// Starts the second step of modular multiplication operation.
///
/// This is a non blocking function that returns without an error if
/// operation is completed successfully. `start_step1` must be called
/// before calling this function.
pub fn start_step2(&mut self, operand_b: &T::InputType) {
while !self.rsa.is_idle() {}
self.rsa.clear_interrupt();
unsafe {
self.rsa.write_operand_a(operand_b);
}
self.start();
}
fn start(&mut self) {
pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_multi_start();
self.rsa.wait_for_idle();
self.rsa.write_operand_a(operand_b);
}
}
@ -152,70 +102,22 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularExponentiation`.
///
/// `m_prime` can be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 24.3.2 of <https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
exponent: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_operand_b(exponent);
rsa.write_modulus(modulus);
}
rsa.write_mprime(m_prime);
Self {
rsa,
phantom: PhantomData,
}
}
/// Sets the modular exponentiation mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_modexp_mode((N / 16 - 1) as u32)
}
/// Starts the modular exponentiation operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_modexp_start();
}
}
impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplication<'a, 'd, T, DM>
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaMultiplication`.
pub fn new(rsa: &'a mut Rsa<'d, DM>) -> Self {
Self::set_mode(rsa);
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the multiplication operation.
pub fn start_multiplication(&mut self, operand_a: &T::InputType, operand_b: &T::InputType) {
unsafe {
self.rsa.write_multi_operand_a(operand_a);
self.rsa.write_multi_operand_b(operand_b);
}
self.start();
}
/// Sets the multiplication mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_multi_mode(((N * 2) / 16 + 7) as u32)
}
/// Starts the multiplication operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_multi_start();
pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_multi_operand_b(operand_b);
}
}

View File

@ -1,4 +1,4 @@
use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping};
use core::convert::Infallible;
use crate::rsa::{
implement_op,
@ -94,21 +94,21 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
/// Starts the modular exponentiation operation.
pub(super) fn write_modexp_start(&mut self) {
pub(super) fn write_modexp_start(&self) {
self.rsa
.set_start_modexp()
.write(|w| w.set_start_modexp().set_bit());
}
/// Starts the multiplication operation.
pub(super) fn write_multi_start(&mut self) {
pub(super) fn write_multi_start(&self) {
self.rsa
.set_start_mult()
.write(|w| w.set_start_mult().set_bit());
}
/// Starts the modular multiplication operation.
fn write_modmulti_start(&mut self) {
pub(super) fn write_modmulti_start(&self) {
self.rsa
.set_start_modmult()
.write(|w| w.set_start_modmult().set_bit());
@ -120,13 +120,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
/// Checks if the RSA peripheral is idle.
pub(super) fn is_idle(&mut self) -> bool {
pub(super) fn is_idle(&self) -> bool {
self.rsa.query_idle().read().query_idle().bit_is_set()
}
unsafe fn write_multi_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N);
}
}
/// Module defining marker types for various RSA operand sizes.
@ -240,34 +236,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularExponentiation`.
///
/// `m_prime` could be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 19.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-c3_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
exponent: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_operand_b(exponent);
rsa.write_modulus(modulus);
}
rsa.write_mprime(m_prime);
if rsa.is_search_enabled() {
rsa.write_search_position(Self::find_search_pos(exponent));
}
Self {
rsa,
phantom: PhantomData,
}
}
fn find_search_pos(exponent: &T::InputType) -> u32 {
pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 {
for (i, byte) in exponent.iter().rev().enumerate() {
if *byte == 0 {
continue;
@ -278,64 +247,21 @@ where
}
/// Sets the modular exponentiation mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N - 1) as u32)
}
/// Starts the modular exponentiation operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_modexp_start();
}
}
impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM>
where
T: RsaMode<InputType = [u32; N]>,
{
fn write_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N - 1) as u32)
}
/// Creates an instance of `RsaModularMultiplication`.
///
/// `m_prime` can be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 19.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-c3_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
operand_a: &T::InputType,
operand_b: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::write_mode(rsa);
rsa.write_mprime(m_prime);
unsafe {
rsa.write_modulus(modulus);
rsa.write_operand_a(operand_a);
rsa.write_operand_b(operand_b);
}
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the modular multiplication operation.
///
/// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
///
/// For more information refer to 19.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-c3_technical_reference_manual_en.pdf>.
pub fn start_modular_multiplication(&mut self, r: &T::InputType) {
unsafe {
self.rsa.write_r(r);
}
self.start();
}
fn start(&mut self) {
self.rsa.write_modmulti_start();
pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_operand_b(operand_b);
}
}
@ -343,33 +269,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaMultiplication`.
pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_operand_a(operand_a);
}
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the multiplication operation.
pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
unsafe {
self.rsa.write_multi_operand_b(operand_b);
}
self.start();
pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_multi_operand_b(operand_b);
}
/// Sets the multiplication mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N * 2 - 1) as u32)
}
/// Starts the multiplication operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_multi_start();
}
}

View File

@ -1,4 +1,4 @@
use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping};
use core::convert::Infallible;
use crate::rsa::{
implement_op,
@ -101,19 +101,19 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
/// Starts the modular exponentiation operation.
pub(super) fn write_modexp_start(&mut self) {
pub(super) fn write_modexp_start(&self) {
self.rsa
.modexp_start()
.write(|w| w.modexp_start().set_bit());
}
/// Starts the multiplication operation.
pub(super) fn write_multi_start(&mut self) {
pub(super) fn write_multi_start(&self) {
self.rsa.mult_start().write(|w| w.mult_start().set_bit());
}
/// Starts the modular multiplication operation.
fn write_modmulti_start(&mut self) {
pub(super) fn write_modmulti_start(&self) {
self.rsa
.modmult_start()
.write(|w| w.modmult_start().set_bit());
@ -127,13 +127,9 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
/// Checks if the RSA peripheral is idle.
pub(super) fn is_idle(&mut self) -> bool {
pub(super) fn is_idle(&self) -> bool {
self.rsa.idle().read().idle().bit_is_set()
}
unsafe fn write_multi_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N);
}
}
pub mod operand_sizes {
@ -281,34 +277,7 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularExponentiation`.
///
/// `m_prime` can be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 20.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-s3_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
exponent: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_operand_b(exponent);
rsa.write_modulus(modulus);
}
rsa.write_mprime(m_prime);
if rsa.is_search_enabled() {
rsa.write_search_position(Self::find_search_pos(exponent));
}
Self {
rsa,
phantom: PhantomData,
}
}
fn find_search_pos(exponent: &T::InputType) -> u32 {
pub(super) fn find_search_pos(exponent: &T::InputType) -> u32 {
for (i, byte) in exponent.iter().rev().enumerate() {
if *byte == 0 {
continue;
@ -319,64 +288,21 @@ where
}
/// Sets the modular exponentiation mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N - 1) as u32)
}
/// Starts the modular exponentiation operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_modexp_start();
}
}
impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplication<'a, 'd, T, DM>
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularMultiplication`.
///
/// `m_prime` could be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 20.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-s3_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
operand_a: &T::InputType,
operand_b: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::write_mode(rsa);
rsa.write_mprime(m_prime);
unsafe {
rsa.write_modulus(modulus);
rsa.write_operand_a(operand_a);
rsa.write_operand_b(operand_b);
}
Self {
rsa,
phantom: PhantomData,
}
}
fn write_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N - 1) as u32)
}
/// Starts the modular multiplication operation.
///
/// `r` could be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
///
/// For more information refer to 19.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-s3_technical_reference_manual_en.pdf>.
pub fn start_modular_multiplication(&mut self, r: &T::InputType) {
unsafe {
self.rsa.write_r(r);
}
self.start();
}
fn start(&mut self) {
self.rsa.write_modmulti_start();
pub(super) fn set_up_modular_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_operand_b(operand_b);
}
}
@ -384,33 +310,12 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaMultiplication`.
pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self {
Self::set_mode(rsa);
unsafe {
rsa.write_operand_a(operand_a);
}
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the multiplication operation.
pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
unsafe {
self.rsa.write_multi_operand_b(operand_b);
}
self.start();
}
/// Sets the multiplication mode for the RSA hardware.
pub(super) fn set_mode(rsa: &mut Rsa<'d, DM>) {
pub(super) fn write_mode(rsa: &mut Rsa<'d, DM>) {
rsa.write_mode((N * 2 - 1) as u32)
}
/// Starts the multiplication operation on the RSA hardware.
pub(super) fn start(&mut self) {
self.rsa.write_multi_start();
pub(super) fn set_up_multiplication(&mut self, operand_b: &T::InputType) {
self.rsa.write_multi_operand_b(operand_b);
}
}

View File

@ -16,16 +16,10 @@
//! ## Examples
//!
//! ### Modular Exponentiation, Modular Multiplication, and Multiplication
//! Visit the [RSA test] for an example of using the peripheral.
//!
//! ## Implementation State
//!
//! - The [nb] crate is used to handle non-blocking operations.
//! - This peripheral supports `async` on every available chip except of `esp32`
//! (to be solved).
//! Visit the [RSA test suite] for an example of using the peripheral.
//!
//! [nb]: https://docs.rs/nb/1.1.0/nb/
//! [RSA test]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs
//! [RSA test suite]: https://github.com/esp-rs/esp-hal/blob/main/hil-test/tests/rsa.rs
use core::{marker::PhantomData, ptr::copy_nonoverlapping};
@ -53,24 +47,6 @@ pub struct Rsa<'d, DM: crate::Mode> {
phantom: PhantomData<DM>,
}
impl<'d, DM: crate::Mode> Rsa<'d, DM> {
fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) {
unsafe {
crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler());
crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority())
.unwrap();
}
}
fn read_results<const N: usize>(&mut self, outbuf: &mut [u32; N]) {
while !self.is_idle() {}
unsafe {
self.read_out(outbuf);
}
self.clear_interrupt();
}
}
impl<'d> Rsa<'d, crate::Blocking> {
/// Create a new instance in [crate::Blocking] mode.
///
@ -111,32 +87,66 @@ impl<'d, DM: crate::Mode> Rsa<'d, DM> {
}
}
unsafe fn write_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N);
fn write_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
unsafe {
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.y_mem(0).as_ptr(), N);
}
}
unsafe fn write_modulus<const N: usize>(&mut self, modulus: &[u32; N]) {
copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N);
fn write_modulus<const N: usize>(&mut self, modulus: &[u32; N]) {
unsafe {
copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem(0).as_ptr(), N);
}
}
fn write_mprime(&mut self, m_prime: u32) {
self.rsa.m_prime().write(|w| unsafe { w.bits(m_prime) });
}
unsafe fn write_operand_a<const N: usize>(&mut self, operand_a: &[u32; N]) {
copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N);
fn write_operand_a<const N: usize>(&mut self, operand_a: &[u32; N]) {
unsafe {
copy_nonoverlapping(operand_a.as_ptr(), self.rsa.x_mem(0).as_ptr(), N);
}
}
unsafe fn write_r<const N: usize>(&mut self, r: &[u32; N]) {
copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N);
fn write_multi_operand_b<const N: usize>(&mut self, operand_b: &[u32; N]) {
unsafe {
copy_nonoverlapping(operand_b.as_ptr(), self.rsa.z_mem(0).as_ptr().add(N), N);
}
}
unsafe fn read_out<const N: usize>(&mut self, outbuf: &mut [u32; N]) {
copy_nonoverlapping(
self.rsa.z_mem(0).as_ptr() as *const u32,
outbuf.as_ptr() as *mut u32,
N,
);
fn write_r<const N: usize>(&mut self, r: &[u32; N]) {
unsafe {
copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem(0).as_ptr(), N);
}
}
fn read_out<const N: usize>(&self, outbuf: &mut [u32; N]) {
unsafe {
copy_nonoverlapping(
self.rsa.z_mem(0).as_ptr() as *const u32,
outbuf.as_ptr() as *mut u32,
N,
);
}
}
fn internal_set_interrupt_handler(&mut self, handler: InterruptHandler) {
unsafe {
crate::interrupt::bind_interrupt(crate::peripherals::Interrupt::RSA, handler.handler());
crate::interrupt::enable(crate::peripherals::Interrupt::RSA, handler.priority())
.unwrap();
}
}
fn wait_for_idle(&mut self) {
while !self.is_idle() {}
self.clear_interrupt();
}
fn read_results<const N: usize>(&mut self, outbuf: &mut [u32; N]) {
self.wait_for_idle();
self.read_out(outbuf);
}
}
@ -155,7 +165,7 @@ pub trait Multi: RsaMode {
macro_rules! implement_op {
(($x:literal, multi)) => {
paste! {
/// Represents an RSA operation for the given bit size with multi-output.
#[doc = concat!($x, "-bit RSA operation.")]
pub struct [<Op $x>];
impl Multi for [<Op $x>] {
@ -204,17 +214,47 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularExponentiati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularExponentiation`.
///
/// `m_prime` could be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 24.3.2 of <https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
exponent: &T::InputType,
modulus: &T::InputType,
m_prime: u32,
) -> Self {
Self::write_mode(rsa);
rsa.write_operand_b(exponent);
rsa.write_modulus(modulus);
rsa.write_mprime(m_prime);
#[cfg(not(esp32))]
if rsa.is_search_enabled() {
rsa.write_search_position(Self::find_search_pos(exponent));
}
Self {
rsa,
phantom: PhantomData,
}
}
fn set_up_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
self.rsa.write_operand_a(base);
self.rsa.write_r(r);
}
/// Starts the modular exponentiation operation.
///
/// `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
///
/// For more information refer to 24.3.2 of <https://www.espressif.com/sites/default/files/documentation/esp32_technical_reference_manual_en.pdf>.
pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) {
unsafe {
self.rsa.write_operand_a(base);
self.rsa.write_r(r);
}
self.start();
self.set_up_exponentiation(base, r);
self.rsa.write_modexp_start();
}
/// Reads the result to the given buffer.
@ -240,6 +280,40 @@ impl<'a, 'd, T: RsaMode, DM: crate::Mode, const N: usize> RsaModularMultiplicati
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaModularMultiplication`.
///
/// - `r` can be calculated using `2 ^ ( bitlength * 2 ) mod modulus`.
/// - `m_prime` can be calculated using `-(modular multiplicative inverse of
/// modulus) mod 2^32`.
///
/// For more information refer to 20.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-s3_technical_reference_manual_en.pdf>.
pub fn new(
rsa: &'a mut Rsa<'d, DM>,
operand_a: &T::InputType,
modulus: &T::InputType,
r: &T::InputType,
m_prime: u32,
) -> Self {
Self::write_mode(rsa);
rsa.write_mprime(m_prime);
rsa.write_modulus(modulus);
rsa.write_operand_a(operand_a);
rsa.write_r(r);
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the modular multiplication operation.
///
/// For more information refer to 19.3.1 of <https://www.espressif.com/sites/default/files/documentation/esp32-c3_technical_reference_manual_en.pdf>.
pub fn start_modular_multiplication(&mut self, operand_b: &T::InputType) {
self.set_up_modular_multiplication(operand_b);
self.rsa.write_modmulti_start();
}
/// Reads the result to the given buffer.
/// This is a non blocking function that returns without an error if
/// operation is completed successfully.
@ -261,6 +335,23 @@ impl<'a, 'd, T: RsaMode + Multi, DM: crate::Mode, const N: usize> RsaMultiplicat
where
T: RsaMode<InputType = [u32; N]>,
{
/// Creates an instance of `RsaMultiplication`.
pub fn new(rsa: &'a mut Rsa<'d, DM>, operand_a: &T::InputType) -> Self {
Self::write_mode(rsa);
rsa.write_operand_a(operand_a);
Self {
rsa,
phantom: PhantomData,
}
}
/// Starts the multiplication operation.
pub fn start_multiplication(&mut self, operand_b: &T::InputType) {
self.set_up_multiplication(operand_b);
self.rsa.write_multi_start();
}
/// Reads the result to the given buffer.
/// This is a non blocking function that returns without an error if
/// operation is completed successfully. `start_multiplication` must be
@ -279,59 +370,67 @@ pub(crate) mod asynch {
use core::task::Poll;
use embassy_sync::waitqueue::AtomicWaker;
use portable_atomic::{AtomicBool, Ordering};
use procmacros::handler;
use crate::rsa::{
Multi,
RsaMode,
RsaModularExponentiation,
RsaModularMultiplication,
RsaMultiplication,
use crate::{
rsa::{
Multi,
Rsa,
RsaMode,
RsaModularExponentiation,
RsaModularMultiplication,
RsaMultiplication,
},
Async,
};
static WAKER: AtomicWaker = AtomicWaker::new();
static SIGNALED: AtomicBool = AtomicBool::new(false);
/// `Future` that waits for the RSA operation to complete.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct RsaFuture<'d> {
instance: &'d crate::peripherals::RSA,
struct RsaFuture<'a, 'd> {
#[cfg_attr(esp32, allow(dead_code))]
instance: &'a Rsa<'d, Async>,
}
impl<'d> RsaFuture<'d> {
/// Asynchronously initializes the RSA peripheral.
pub fn new(instance: &'d crate::peripherals::RSA) -> Self {
impl<'a, 'd> RsaFuture<'a, 'd> {
fn new(instance: &'a Rsa<'d, Async>) -> Self {
SIGNALED.store(false, Ordering::Relaxed);
cfg_if::cfg_if! {
if #[cfg(esp32)] {
instance.interrupt().modify(|_, w| w.interrupt().set_bit());
} else if #[cfg(any(esp32s2, esp32s3))] {
instance.interrupt_ena().modify(|_, w| w.interrupt_ena().set_bit());
instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().set_bit());
} else {
instance.int_ena().modify(|_, w| w.int_ena().set_bit());
instance.rsa.int_ena().write(|w| w.int_ena().set_bit());
}
}
Self { instance }
}
fn event_bit_is_clear(&self) -> bool {
fn is_done(&self) -> bool {
SIGNALED.load(Ordering::Acquire)
}
}
impl Drop for RsaFuture<'_, '_> {
fn drop(&mut self) {
cfg_if::cfg_if! {
if #[cfg(esp32)] {
self.instance.interrupt().read().interrupt().bit_is_clear()
} else if #[cfg(any(esp32s2, esp32s3))] {
self
.instance
.interrupt_ena()
.read()
.interrupt_ena()
.bit_is_clear()
self.instance.rsa.interrupt_ena().write(|w| w.interrupt_ena().clear_bit());
} else {
self.instance.int_ena().read().int_ena().bit_is_clear()
self.instance.rsa.int_ena().write(|w| w.int_ena().clear_bit());
}
}
}
}
impl<'d> core::future::Future for RsaFuture<'d> {
impl core::future::Future for RsaFuture<'_, '_> {
type Output = ();
fn poll(
@ -339,7 +438,7 @@ pub(crate) mod asynch {
cx: &mut core::task::Context<'_>,
) -> core::task::Poll<Self::Output> {
WAKER.register(cx.waker());
if self.event_bit_is_clear() {
if self.is_done() {
Poll::Ready(())
} else {
Poll::Pending
@ -347,7 +446,7 @@ pub(crate) mod asynch {
}
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, crate::Async>
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T, Async>
where
T: RsaMode<InputType = [u32; N]>,
{
@ -358,49 +457,47 @@ pub(crate) mod asynch {
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_exponentiation(base, r);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
self.set_up_exponentiation(base, r);
let fut = RsaFuture::new(self.rsa);
self.rsa.write_modexp_start();
fut.await;
self.rsa.read_out(outbuf);
}
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, crate::Async>
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T, Async>
where
T: RsaMode<InputType = [u32; N]>,
{
#[cfg(not(esp32))]
/// Asynchronously performs an RSA modular multiplication operation.
pub async fn modular_multiplication(
&mut self,
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_modular_multiplication(r);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
#[cfg(esp32)]
/// Asynchronously performs an RSA modular multiplication operation.
pub async fn modular_multiplication(
&mut self,
operand_a: &T::InputType,
operand_b: &T::InputType,
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_step1(operand_a, r);
self.start_step2(operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
cfg_if::cfg_if! {
if #[cfg(esp32)] {
let fut = RsaFuture::new(self.rsa);
self.rsa.write_multi_start();
fut.await;
self.rsa.write_operand_a(operand_b);
} else {
self.set_up_modular_multiplication(operand_b);
}
}
let fut = RsaFuture::new(self.rsa);
self.rsa.write_modmulti_start();
fut.await;
self.rsa.read_out(outbuf);
}
}
impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, crate::Async>
impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T, Async>
where
T: RsaMode<InputType = [u32; N]>,
{
#[cfg(not(esp32))]
/// Asynchronously performs an RSA multiplication operation.
pub async fn multiplication<'b, const O: usize>(
&mut self,
@ -409,44 +506,28 @@ pub(crate) mod asynch {
) where
T: Multi<OutputType = [u32; O]>,
{
self.start_multiplication(operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
#[cfg(esp32)]
/// Asynchronously performs an RSA multiplication operation.
pub async fn multiplication<'b, const O: usize>(
&mut self,
operand_a: &T::InputType,
operand_b: &T::InputType,
outbuf: &mut T::OutputType,
) where
T: Multi<OutputType = [u32; O]>,
{
self.start_multiplication(operand_a, operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
self.set_up_multiplication(operand_b);
let fut = RsaFuture::new(self.rsa);
self.rsa.write_multi_start();
fut.await;
self.rsa.read_out(outbuf);
}
}
#[handler]
/// Interrupt handler for RSA.
pub(super) fn rsa_interrupt_handler() {
#[cfg(not(any(esp32, esp32s2, esp32s3)))]
unsafe { &*crate::peripherals::RSA::ptr() }
.int_ena()
.modify(|_, w| w.int_ena().clear_bit());
#[cfg(esp32)]
unsafe { &*crate::peripherals::RSA::ptr() }
.interrupt()
.modify(|_, w| w.interrupt().clear_bit());
#[cfg(any(esp32s2, esp32s3))]
unsafe { &*crate::peripherals::RSA::ptr() }
.interrupt_ena()
.modify(|_, w| w.interrupt_ena().clear_bit());
let rsa = unsafe { &*crate::peripherals::RSA::ptr() };
SIGNALED.store(true, Ordering::Release);
cfg_if::cfg_if! {
if #[cfg(esp32)] {
rsa.interrupt().write(|w| w.interrupt().set_bit());
} else if #[cfg(any(esp32s2, esp32s3))] {
rsa.clear_interrupt().write(|w| w.clear_interrupt().set_bit());
} else {
rsa.int_clr().write(|w| w.clear_interrupt().set_bit());
}
}
WAKER.wake();
}

View File

@ -107,6 +107,10 @@ harness = false
name = "rsa"
harness = false
[[test]]
name = "rsa_async"
harness = false
[[test]]
name = "sha"
harness = false

View File

@ -10,7 +10,7 @@ use esp_hal::{
peripherals::Peripherals,
prelude::*,
rsa::{
operand_sizes,
operand_sizes::*,
Rsa,
RsaModularExponentiation,
RsaModularMultiplication,
@ -37,16 +37,6 @@ struct Context<'a> {
rsa: Rsa<'a, Blocking>,
}
impl Context<'_> {
pub fn init() -> Self {
let peripherals = Peripherals::take();
let mut rsa = Rsa::new(peripherals.RSA);
nb::block!(rsa.ready()).unwrap();
Context { rsa }
}
}
const fn compute_r(modulus: &U512) -> U512 {
let mut d = [0_u32; U512::LIMBS * 2 + 1];
d[d.len() - 1] = 1;
@ -68,10 +58,15 @@ mod tests {
#[init]
fn init() -> Context<'static> {
Context::init()
let peripherals = Peripherals::take();
let mut rsa = Rsa::new(peripherals.RSA);
nb::block!(rsa.ready()).unwrap();
Context { rsa }
}
#[test]
#[timeout(5)]
fn test_modular_exponentiation(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [
1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085,
@ -85,20 +80,20 @@ mod tests {
ctx.rsa.enable_disable_search_acceleration(true);
}
let mut outbuf = [0_u32; U512::LIMBS];
let mut mod_exp = RsaModularExponentiation::<operand_sizes::Op512, esp_hal::Blocking>::new(
let mut mod_exp = RsaModularExponentiation::<Op512, _>::new(
&mut ctx.rsa,
BIGNUM_2.as_words(),
BIGNUM_3.as_words(),
compute_mprime(&BIGNUM_3),
);
let r = compute_r(&BIGNUM_3);
let base = &BIGNUM_1.as_words();
mod_exp.start_exponentiation(&base, r.as_words());
mod_exp.start_exponentiation(BIGNUM_1.as_words(), r.as_words());
mod_exp.read_results(&mut outbuf);
assert_eq!(EXPECTED_OUTPUT, outbuf);
}
#[test]
#[timeout(5)]
fn test_modular_multiplication(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [
1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870,
@ -107,31 +102,21 @@ mod tests {
];
let mut outbuf = [0_u32; U512::LIMBS];
let mut mod_multi =
RsaModularMultiplication::<operand_sizes::Op512, esp_hal::Blocking>::new(
&mut ctx.rsa,
#[cfg(not(feature = "esp32"))]
BIGNUM_1.as_words(),
#[cfg(not(feature = "esp32"))]
BIGNUM_2.as_words(),
BIGNUM_3.as_words(),
compute_mprime(&BIGNUM_3),
);
let r = compute_r(&BIGNUM_3);
#[cfg(feature = "esp32")]
{
mod_multi.start_step1(BIGNUM_1.as_words(), r.as_words());
mod_multi.start_step2(BIGNUM_2.as_words());
}
#[cfg(not(feature = "esp32"))]
{
mod_multi.start_modular_multiplication(r.as_words());
}
let mut mod_multi = RsaModularMultiplication::<Op512, _>::new(
&mut ctx.rsa,
BIGNUM_1.as_words(),
BIGNUM_3.as_words(),
r.as_words(),
compute_mprime(&BIGNUM_3),
);
mod_multi.start_modular_multiplication(BIGNUM_2.as_words());
mod_multi.read_results(&mut outbuf);
assert_eq!(EXPECTED_OUTPUT, outbuf);
}
#[test]
#[timeout(5)]
fn test_multiplication(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [
1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767,
@ -145,21 +130,10 @@ mod tests {
let operand_a = BIGNUM_1.as_words();
let operand_b = BIGNUM_2.as_words();
cfg_if::cfg_if! {
if #[cfg(feature = "esp32")] {
let mut rsamulti =
RsaMultiplication::<operand_sizes::Op512, esp_hal::Blocking>::new(&mut ctx.rsa);
rsamulti.start_multiplication(operand_a, operand_b);
rsamulti.read_results(&mut outbuf);
} else {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512, esp_hal::Blocking>::new(
&mut ctx.rsa,
operand_a,
);
rsamulti.start_multiplication(operand_b);
rsamulti.read_results(&mut outbuf);
}
}
let mut rsamulti = RsaMultiplication::<Op512, _>::new(&mut ctx.rsa, operand_a);
rsamulti.start_multiplication(operand_b);
rsamulti.read_results(&mut outbuf);
assert_eq!(EXPECTED_OUTPUT, outbuf)
}
}

140
hil-test/tests/rsa_async.rs Normal file
View File

@ -0,0 +1,140 @@
//! Async RSA Test
//% CHIPS: esp32 esp32c3 esp32c6 esp32h2 esp32s2 esp32s3
#![no_std]
#![no_main]
use crypto_bigint::{Uint, U1024, U512};
use esp_hal::{
peripherals::Peripherals,
prelude::*,
rsa::{
operand_sizes::*,
Rsa,
RsaModularExponentiation,
RsaModularMultiplication,
RsaMultiplication,
},
Async,
};
use hil_test as _;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878",
);
const BIGNUM_2: U512 = Uint::from_be_hex(
"1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\
6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51",
);
const BIGNUM_3: U512 = Uint::from_be_hex(
"6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\
67aeafa9ba25feb8712f188cb139b7d9b9af1c361",
);
struct Context<'a> {
rsa: Rsa<'a, Async>,
}
const fn compute_r(modulus: &U512) -> U512 {
let mut d = [0_u32; U512::LIMBS * 2 + 1];
d[d.len() - 1] = 1;
let d = Uint::from_words(d);
d.const_rem(&modulus.resize()).0.resize()
}
const fn compute_mprime(modulus: &U512) -> u32 {
let m_inv = modulus.inv_mod2k(32).to_words()[0];
(-1 * m_inv as i64 % 4294967296) as u32
}
#[cfg(test)]
#[embedded_test::tests(executor = esp_hal_embassy::Executor::new())]
mod tests {
use defmt::assert_eq;
use super::*;
#[init]
fn init() -> Context<'static> {
let peripherals = Peripherals::take();
let mut rsa = Rsa::new_async(peripherals.RSA);
nb::block!(rsa.ready()).unwrap();
Context { rsa }
}
#[test]
#[timeout(5)]
async fn modular_exponentiation(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [
1601059419, 3994655875, 2600857657, 1530060852, 64828275, 4221878473, 2751381085,
1938128086, 625895085, 2087010412, 2133352910, 101578249, 3798099415, 3357588690,
2065243474, 330914193,
];
#[cfg(not(feature = "esp32"))]
{
ctx.rsa.enable_disable_constant_time_acceleration(true);
ctx.rsa.enable_disable_search_acceleration(true);
}
let mut outbuf = [0_u32; U512::LIMBS];
let mut mod_exp = RsaModularExponentiation::<Op512, _>::new(
&mut ctx.rsa,
BIGNUM_2.as_words(),
BIGNUM_3.as_words(),
compute_mprime(&BIGNUM_3),
);
let r = compute_r(&BIGNUM_3);
mod_exp
.exponentiation(BIGNUM_1.as_words(), r.as_words(), &mut outbuf)
.await;
assert_eq!(EXPECTED_OUTPUT, outbuf);
}
#[test]
#[timeout(5)]
async fn test_modular_multiplication(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U512::LIMBS] = [
1868256644, 833470784, 4187374062, 2684021027, 191862388, 1279046003, 1929899870,
4209598061, 3830489207, 1317083344, 2666864448, 3701382766, 3232598924, 2904609522,
747558855, 479377985,
];
let mut outbuf = [0_u32; U512::LIMBS];
let r = compute_r(&BIGNUM_3);
let mut mod_multi = RsaModularMultiplication::<Op512, _>::new(
&mut ctx.rsa,
BIGNUM_1.as_words(),
BIGNUM_3.as_words(),
r.as_words(),
compute_mprime(&BIGNUM_3),
);
mod_multi
.modular_multiplication(BIGNUM_2.as_words(), &mut outbuf)
.await;
assert_eq!(EXPECTED_OUTPUT, outbuf);
}
#[test]
#[timeout(5)]
async fn test_multiplication(mut ctx: Context<'static>) {
const EXPECTED_OUTPUT: [u32; U1024::LIMBS] = [
1264702968, 3552243420, 2602501218, 498422249, 2431753435, 2307424767, 349202767,
2269697177, 1525551459, 3623276361, 3146383138, 191420847, 4252021895, 9176459,
301757643, 4220806186, 434407318, 3722444851, 1850128766, 928651940, 107896699,
563405838, 1834067613, 1289630401, 3145128058, 3300293535, 3077505758, 1926648662,
1264151247, 3626086486, 3701894076, 306518743,
];
let mut outbuf = [0_u32; U1024::LIMBS];
let operand_a = BIGNUM_1.as_words();
let operand_b = BIGNUM_2.as_words();
let mut rsamulti = RsaMultiplication::<Op512, _>::new(&mut ctx.rsa, operand_a);
rsamulti.multiplication(operand_b, &mut outbuf).await;
assert_eq!(EXPECTED_OUTPUT, outbuf)
}
}