diff --git a/CHANGELOG.md b/CHANGELOG.md index 99a2a6778..6c950c433 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add initial support for the ESP32-H2 (#513) - Add bare-bones PSRAM support for ESP32-S3 (#517) - Add async support to the I2C driver (#519) +- Add initial support for RSA in ESP32-H2 (#526) - Add initial support for SHA in ESP32-H2 (#527) - Add initial support for AES in ESP32-H2 (#528) - Add blinky_erased_pins example for ESP32-H2 (#530) diff --git a/esp-hal-common/devices/esp32h2.toml b/esp-hal-common/devices/esp32h2.toml index d4034ce67..6eaf69b59 100644 --- a/esp-hal-common/devices/esp32h2.toml +++ b/esp-hal-common/devices/esp32h2.toml @@ -41,7 +41,7 @@ peripherals = [ # "pmu", # "rmt", # "rng", - # "rsa", + "rsa", "sha", # "soc_etm", # "spi0", diff --git a/esp-hal-common/src/rsa/mod.rs b/esp-hal-common/src/rsa/mod.rs index b02673738..258ae70cb 100644 --- a/esp-hal-common/src/rsa/mod.rs +++ b/esp-hal-common/src/rsa/mod.rs @@ -28,6 +28,7 @@ use crate::{ #[cfg_attr(esp32s3, path = "esp32sX.rs")] #[cfg_attr(esp32c3, path = "esp32cX.rs")] #[cfg_attr(esp32c6, path = "esp32cX.rs")] +#[cfg_attr(esp32h2, path = "esp32cX.rs")] #[cfg_attr(esp32, path = "esp32.rs")] mod rsa_spec_impl; diff --git a/esp-hal-common/src/soc/esp32h2/peripherals.rs b/esp-hal-common/src/soc/esp32h2/peripherals.rs index b9ae314ad..2e790f1be 100644 --- a/esp-hal-common/src/soc/esp32h2/peripherals.rs +++ b/esp-hal-common/src/soc/esp32h2/peripherals.rs @@ -43,7 +43,7 @@ crate::peripherals! { // PMU => true, // RMT => true, // RNG => true, - // RSA => true, + RSA => true, SHA => true, // SOC_ETM => true, // SPI0 => true, diff --git a/esp32c6-hal/examples/rsa.rs b/esp32c6-hal/examples/rsa.rs index 6c4e69ae9..a69f79c33 100644 --- a/esp32c6-hal/examples/rsa.rs +++ b/esp32c6-hal/examples/rsa.rs @@ -60,8 +60,6 @@ fn main() -> ! { let mut system = peripherals.PCR.split(); let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); - // Disable the watchdog timers. For the ESP32-C6, this includes the Super WDT, - // and the TIMG WDTs. // Disable the watchdog timers. For the ESP32-C6, this includes the Super WDT, // and the TIMG WDTs. let mut rtc = Rtc::new(peripherals.LP_CLKRST); @@ -83,11 +81,6 @@ fn main() -> ! { wdt0.disable(); wdt1.disable(); - rtc.swd.disable(); - rtc.rwdt.disable(); - wdt0.disable(); - wdt1.disable(); - let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); block!(rsa.ready()).unwrap(); diff --git a/esp32h2-hal/Cargo.toml b/esp32h2-hal/Cargo.toml index a4b5852cb..71d195dbb 100644 --- a/esp32h2-hal/Cargo.toml +++ b/esp32h2-hal/Cargo.toml @@ -37,6 +37,7 @@ esp-hal-common = { version = "0.9.0", features = ["esp32h2"], path = "../es [dev-dependencies] aes = "0.8.2" critical-section = "1.1.1" +crypto-bigint = { version = "0.5.2", default-features = false } embassy-executor = { version = "0.2.0", features = ["nightly", "integrated-timers"] } embedded-graphics = "0.7.1" esp-backtrace = { version = "0.7.0", features = ["esp32h2", "panic-handler", "exception-handler", "print-uart"] } diff --git a/esp32h2-hal/examples/rsa.rs b/esp32h2-hal/examples/rsa.rs new file mode 100644 index 000000000..63a25b932 --- /dev/null +++ b/esp32h2-hal/examples/rsa.rs @@ -0,0 +1,178 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32h2_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + systimer::SystemTimer, + timer::TimerGroup, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ +7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ +6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ +67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.PCR.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the watchdog timers. For the ESP32-H2, this includes the Super WDT, + // and the TIMG WDTs. + let mut rtc = Rtc::new(peripherals.LP_CLKRST); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.swd.disable(); + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + + let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); + + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_1.to_le_bytes(), + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = SystemTimer::now(); + mod_multi.start_modular_multiplication(&r); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = SystemTimer::now(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = SystemTimer::now(); + mod_exp.start_exponentiation(&base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = SystemTimer::now(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let mut rsamulti = RsaMultiplication::::new(rsa, &operand_a); + let pre_hw_mul = SystemTimer::now(); + rsamulti.start_multiplication(&operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = SystemTimer::now(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = SystemTimer::now(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = SystemTimer::now(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +}