Adding async support for RSA peripheral (#790)

* Adding async support for RSA peripheral

* Add esp32 support (doesn't work properly yet)

* Xtensa chips are supported (except of esp32)

Add modular multiplication for esp32

Adding a CHANGELOG entry

Rebase issue fix

* Code cleanup

* Add `.await` on `RsaFuture::new()` calls

* Refactor and rebase

Made `read_results` functions to be `async`, got rid of `nb` usage

* Change API methods naming + refactor `start_step2` method

* Adjust example to the API change + documentation

* Code cleaning + refactoring

Update examples
This commit is contained in:
Kirill Mikhailov 2023-09-27 18:03:06 +02:00 committed by GitHub
parent ae160d66c3
commit 24c5e8cb79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 258 additions and 68 deletions

View File

@ -66,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `defmt` feature to enable log output (#773)
- A new macro to load LP core code on ESP32-C6 (#779)
- Add `ECC`` peripheral driver (#785)
- Adding async support for RSA peripheral(doesn't work properly for `esp32` chip - issue will be created)(#790)
### Changed

View File

@ -127,16 +127,17 @@ where
/// This is a non blocking function that returns without an error if
/// operation is completed successfully. `start_step1` must be called
/// before calling this function.
pub fn start_step2(&mut self, operand_b: &T::InputType) -> nb::Result<(), Infallible> {
if !self.rsa.is_idle() {
return Err(nb::Error::WouldBlock);
pub fn start_step2(&mut self, operand_b: &T::InputType) {
loop {
if self.rsa.is_idle() {
self.rsa.clear_interrupt();
unsafe {
self.rsa.write_operand_a(operand_b);
}
self.set_start();
break;
}
}
self.rsa.clear_interrupt();
unsafe {
self.rsa.write_operand_a(operand_b);
}
self.set_start();
Ok(())
}
fn set_start(&mut self) {

View File

@ -27,6 +27,31 @@
//!
//! let mut rsa = Rsa::new(peripherals.RSA);
//! ```
//!
//! ### Async (modular exponentiation)
//! ```no_run
//! #[embassy_executor::task]
//! async fn mod_exp_example(mut rsa: Rsa<'static>) {
//! let mut outbuf = [0_u8; U512::BYTES];
//! let mut mod_exp = RsaModularExponentiation::<operand_sizes::Op512>::new(
//! &mut rsa,
//! &BIGNUM_2.to_le_bytes(),
//! &BIGNUM_3.to_le_bytes(),
//! compute_mprime(&BIGNUM_3),
//! );
//! let r = compute_r(&BIGNUM_3).to_le_bytes();
//! let base = &BIGNUM_1.to_le_bytes();
//! mod_exp.exponentiation(&base, &r, &mut outbuf).await;
//! let residue_params = DynResidueParams::new(&BIGNUM_3);
//! let residue = DynResidue::new(&BIGNUM_1, residue_params);
//! let sw_out = residue.pow(&BIGNUM_2);
//! assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve());
//! println!("modular exponentiation done");
//! }
//! ```
//! This peripheral supports `async` on every available chip except of `esp32`
//! (to be solved).
//!
//! ⚠️: The examples for RSA peripheral are quite extensive, so for a more
//! detailed study of how to use this driver please visit [the repository
@ -35,7 +60,7 @@
//! [nb]: https://docs.rs/nb/1.1.0/nb/
//! [the repository with corresponding example]: https://github.com/esp-rs/esp-hal/blob/main/esp32-hal/examples/rsa.rs
use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping};
use core::{marker::PhantomData, ptr::copy_nonoverlapping};
use crate::{
peripheral::{Peripheral, PeripheralRef},
@ -169,15 +194,16 @@ where
/// This is a non blocking function that returns without an error if
/// operation is completed successfully. `start_exponentiation` must be
/// called before calling this function.
pub fn read_results(&mut self, outbuf: &mut T::InputType) -> nb::Result<(), Infallible> {
if !self.rsa.is_idle() {
return Err(nb::Error::WouldBlock);
pub fn read_results(&mut self, outbuf: &mut T::InputType) {
loop {
if self.rsa.is_idle() {
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
break;
}
}
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
Ok(())
}
}
@ -197,15 +223,16 @@ where
/// Reads the result to the given buffer.
/// This is a non blocking function that returns without an error if
/// operation is completed successfully.
pub fn read_results(&mut self, outbuf: &mut T::InputType) -> nb::Result<(), Infallible> {
if !self.rsa.is_idle() {
return Err(nb::Error::WouldBlock);
pub fn read_results(&mut self, outbuf: &mut T::InputType) {
loop {
if self.rsa.is_idle() {
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
break;
}
}
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
Ok(())
}
}
@ -226,20 +253,187 @@ where
/// This is a non blocking function that returns without an error if
/// operation is completed successfully. `start_multiplication` must be
/// called before calling this function.
pub fn read_results<'b, const O: usize>(
&mut self,
outbuf: &mut T::OutputType,
) -> nb::Result<(), Infallible>
pub fn read_results<'b, const O: usize>(&mut self, outbuf: &mut T::OutputType)
where
T: Multi<OutputType = [u8; O]>,
{
if !self.rsa.is_idle() {
return Err(nb::Error::WouldBlock);
loop {
if self.rsa.is_idle() {
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
break;
}
}
unsafe {
self.rsa.read_out(outbuf);
}
self.rsa.clear_interrupt();
Ok(())
}
}
#[cfg(feature = "async")]
pub(crate) mod asynch {
use core::task::Poll;
use embassy_sync::waitqueue::AtomicWaker;
use procmacros::interrupt;
use crate::rsa::{
Multi,
RsaMode,
RsaModularExponentiation,
RsaModularMultiplication,
RsaMultiplication,
};
static WAKER: AtomicWaker = AtomicWaker::new();
pub(crate) struct RsaFuture<'d> {
instance: &'d crate::peripherals::RSA,
}
impl<'d> RsaFuture<'d> {
pub async fn new(instance: &'d crate::peripherals::RSA) -> Self {
#[cfg(not(any(esp32, esp32s2, esp32s3)))]
instance.int_ena.modify(|_, w| w.int_ena().set_bit());
#[cfg(any(esp32s2, esp32s3))]
instance
.interrupt_ena
.modify(|_, w| w.interrupt_ena().set_bit());
#[cfg(esp32)]
instance.interrupt.modify(|_, w| w.interrupt().set_bit());
Self { instance }
}
fn event_bit_is_clear(&self) -> bool {
#[cfg(not(any(esp32, esp32s2, esp32s3)))]
return self.instance.int_ena.read().int_ena().bit_is_clear();
#[cfg(any(esp32s2, esp32s3))]
return self
.instance
.interrupt_ena
.read()
.interrupt_ena()
.bit_is_clear();
#[cfg(esp32)]
return self.instance.interrupt.read().interrupt().bit_is_clear();
}
}
impl<'d> core::future::Future for RsaFuture<'d> {
type Output = ();
fn poll(
self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> core::task::Poll<Self::Output> {
WAKER.register(cx.waker());
if self.event_bit_is_clear() {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T>
where
T: RsaMode<InputType = [u8; N]>,
{
pub async fn exponentiation(
&mut self,
base: &T::InputType,
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_exponentiation(&base, &r);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
}
impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T>
where
T: RsaMode<InputType = [u8; N]>,
{
#[cfg(not(esp32))]
pub async fn modular_multiplication(
&mut self,
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_modular_multiplication(r);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
#[cfg(esp32)]
pub async fn modular_multiplication(
&mut self,
operand_a: &T::InputType,
operand_b: &T::InputType,
r: &T::InputType,
outbuf: &mut T::InputType,
) {
self.start_step1(operand_a, r);
self.start_step2(operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
}
impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T>
where
T: RsaMode<InputType = [u8; N]>,
{
#[cfg(not(esp32))]
pub async fn multiplication<'b, const O: usize>(
&mut self,
operand_b: &T::InputType,
outbuf: &mut T::OutputType,
) where
T: Multi<OutputType = [u8; O]>,
{
self.start_multiplication(operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
#[cfg(esp32)]
pub async fn multiplication<'b, const O: usize>(
&mut self,
operand_a: &T::InputType,
operand_b: &T::InputType,
outbuf: &mut T::OutputType,
) where
T: Multi<OutputType = [u8; O]>,
{
self.start_multiplication(operand_a, operand_b);
RsaFuture::new(&self.rsa.rsa).await;
self.read_results(outbuf);
}
}
#[interrupt]
fn RSA() {
#[cfg(not(any(esp32, esp32s2, esp32s3)))]
unsafe { &*crate::peripherals::RSA::ptr() }
.int_ena
.modify(|_, w| w.int_ena().clear_bit());
#[cfg(esp32)]
unsafe { &*crate::peripherals::RSA::ptr() }
.interrupt
.modify(|_, w| w.interrupt().clear_bit());
#[cfg(any(esp32s2, esp32s3))]
unsafe { &*crate::peripherals::RSA::ptr() }
.interrupt_ena
.modify(|_, w| w.interrupt_ena().clear_bit());
WAKER.wake();
}
}

View File

@ -26,7 +26,6 @@ use esp32_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let rsa = peripherals.RSA;
let mut rsa = Rsa::new(rsa);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -78,8 +77,8 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = xtensa_lx::timer::get_cycle_count();
mod_multi.start_step1(&BIGNUM_1.to_le_bytes(), &r);
block!(mod_multi.start_step2(&BIGNUM_2.to_le_bytes())).unwrap();
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.start_step2(&BIGNUM_2.to_le_bytes());
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular multiplication",
@ -112,7 +111,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = xtensa_lx::timer::get_cycle_count();
mod_exp.start_exponentiation(base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular exponentiation",
@ -138,7 +137,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let operand_b = &BIGNUM_2.to_le_bytes();
let pre_hw_mul = xtensa_lx::timer::get_cycle_count();
rsamulti.start_multiplication(&operand_a, &operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw multiplication",

View File

@ -26,7 +26,6 @@ use esp32c3_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let mut rsa = Rsa::new(peripherals.RSA);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -80,7 +79,7 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = SystemTimer::now();
mod_multi.start_modular_multiplication(&r);
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = SystemTimer::now();
println!(
"it took {} cycles for hw modular multiplication",
@ -114,7 +113,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = SystemTimer::now();
mod_exp.start_exponentiation(&base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = SystemTimer::now();
println!(
"it took {} cycles for hw modular exponentiation",
@ -140,7 +139,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512>::new(rsa, &operand_a);
let pre_hw_mul = SystemTimer::now();
rsamulti.start_multiplication(&operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = SystemTimer::now();
println!(
"it took {} cycles for hw multiplication",

View File

@ -26,7 +26,6 @@ use esp32c6_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let mut rsa = Rsa::new(peripherals.RSA);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -80,7 +79,7 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = SystemTimer::now();
mod_multi.start_modular_multiplication(&r);
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = SystemTimer::now();
println!(
"it took {} cycles for hw modular multiplication",
@ -114,7 +113,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = SystemTimer::now();
mod_exp.start_exponentiation(&base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = SystemTimer::now();
println!(
"it took {} cycles for hw modular exponentiation",
@ -140,7 +139,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512>::new(rsa, &operand_a);
let pre_hw_mul = SystemTimer::now();
rsamulti.start_multiplication(&operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = SystemTimer::now();
println!(
"it took {} cycles for hw multiplication",

View File

@ -26,7 +26,6 @@ use esp32h2_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let mut rsa = Rsa::new(peripherals.RSA);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -80,7 +79,7 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = SystemTimer::now();
mod_multi.start_modular_multiplication(&r);
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = SystemTimer::now();
println!(
"it took {} cycles for hw modular multiplication",
@ -114,7 +113,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = SystemTimer::now();
mod_exp.start_exponentiation(&base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = SystemTimer::now();
println!(
"it took {} cycles for hw modular exponentiation",
@ -140,7 +139,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512>::new(rsa, &operand_a);
let pre_hw_mul = SystemTimer::now();
rsamulti.start_multiplication(&operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = SystemTimer::now();
println!(
"it took {} cycles for hw multiplication",

View File

@ -26,7 +26,6 @@ use esp32s2_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let mut rsa = Rsa::new(peripherals.RSA);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -81,7 +80,7 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = xtensa_lx::timer::get_cycle_count();
mod_multi.start_modular_multiplication(&r);
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular multiplication",
@ -115,7 +114,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = xtensa_lx::timer::get_cycle_count();
mod_exp.start_exponentiation(&base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular exponentiation",
@ -141,7 +140,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512>::new(rsa, &operand_a);
let pre_hw_mul = xtensa_lx::timer::get_cycle_count();
rsamulti.start_multiplication(&operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw multiplication",

View File

@ -26,7 +26,6 @@ use esp32s3_hal::{
};
use esp_backtrace as _;
use esp_println::println;
use nb::block;
const BIGNUM_1: U512 = Uint::from_be_hex(
"c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\
@ -61,7 +60,7 @@ fn main() -> ! {
let mut rsa = Rsa::new(peripherals.RSA);
block!(rsa.ready()).unwrap();
nb::block!(rsa.ready()).unwrap();
mod_exp_example(&mut rsa);
mod_multi_example(&mut rsa);
multiplication_example(&mut rsa);
@ -81,7 +80,7 @@ fn mod_multi_example(rsa: &mut Rsa) {
let r = compute_r(&BIGNUM_3).to_le_bytes();
let pre_hw_modmul = xtensa_lx::timer::get_cycle_count();
mod_multi.start_modular_multiplication(&r);
block!(mod_multi.read_results(&mut outbuf)).unwrap();
mod_multi.read_results(&mut outbuf);
let post_hw_modmul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular multiplication",
@ -115,7 +114,7 @@ fn mod_exp_example(rsa: &mut Rsa) {
let base = &BIGNUM_1.to_le_bytes();
let pre_hw_exp = xtensa_lx::timer::get_cycle_count();
mod_exp.start_exponentiation(&base, &r);
block!(mod_exp.read_results(&mut outbuf)).unwrap();
mod_exp.read_results(&mut outbuf);
let post_hw_exp = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw modular exponentiation",
@ -141,7 +140,7 @@ fn multiplication_example(rsa: &mut Rsa) {
let mut rsamulti = RsaMultiplication::<operand_sizes::Op512>::new(rsa, &operand_a);
let pre_hw_mul = xtensa_lx::timer::get_cycle_count();
rsamulti.start_multiplication(&operand_b);
block!(rsamulti.read_results(&mut out)).unwrap();
rsamulti.read_results(&mut out);
let post_hw_mul = xtensa_lx::timer::get_cycle_count();
println!(
"it took {} cycles for hw multiplication",