split out connection logic a bit

This commit is contained in:
Robin Appelman
2020-09-02 19:02:47 +02:00
parent 814cdebb18
commit 68f31495bc
4 changed files with 172 additions and 143 deletions

138
src/connection.rs Normal file
View File

@@ -0,0 +1,138 @@
use crate::encoder::SlipEncoder;
use crate::error::RomError;
use crate::Error;
use bytemuck::{from_bytes, Pod, Zeroable};
use serial::SerialPort;
use slip_codec::Decoder;
use std::io::Write;
use std::thread::sleep;
use std::time::Duration;
pub struct Connection {
serial: Box<dyn SerialPort>,
decoder: Decoder,
}
#[derive(Debug, Zeroable, Pod, Copy, Clone)]
#[repr(C)]
#[repr(packed)]
pub struct CommandResponse {
pub resp: u8,
pub return_op: u8,
pub return_length: u16,
pub value: u32,
pub status: u8,
pub error: u8,
}
impl Connection {
pub fn new(serial: impl SerialPort + 'static) -> Self {
Connection {
serial: Box::new(serial),
decoder: Decoder::new(1024),
}
}
pub fn reset_to_flash(&mut self) -> Result<(), Error> {
self.serial.set_dtr(false)?;
self.serial.set_rts(true)?;
sleep(Duration::from_millis(100));
self.serial.set_dtr(true)?;
self.serial.set_rts(false)?;
sleep(Duration::from_millis(50));
self.serial.set_dtr(true)?;
Ok(())
}
pub fn read_response(&mut self, timeout: u64) -> Result<Option<CommandResponse>, Error> {
let response = self.read(timeout)?;
if response.len() < 10 {
return Ok(None);
}
let header: CommandResponse = *from_bytes(&response[0..10]);
Ok(Some(header))
}
pub fn write_command(
&mut self,
command: u8,
data: impl LazyBytes<Box<dyn SerialPort>>,
check: u32,
) -> Result<(), Error> {
let mut encoder = SlipEncoder::new(&mut self.serial)?;
encoder.write(&[0])?;
encoder.write(&[command])?;
encoder.write(&(data.length().to_le_bytes()))?;
encoder.write(&(check.to_le_bytes()))?;
data.write(&mut encoder)?;
encoder.finish()?;
Ok(())
}
pub fn command<'a>(
&mut self,
command: u8,
data: impl LazyBytes<Box<dyn SerialPort>>,
check: u32,
timeout: u64,
) -> Result<CommandResponse, Error> {
self.write_command(command, data, check)?;
match self.read_response(timeout)? {
Some(response) if response.return_op == command as u8 => {
if response.status == 1 {
Err(Error::RomError(RomError::from(response.error)))
} else {
Ok(response)
}
}
_ => Err(Error::ConnectionFailed),
}
}
fn read(&mut self, timeout: u64) -> Result<Vec<u8>, Error> {
self.serial
.set_timeout(Duration::from_millis(timeout))
.unwrap();
Ok(self.decoder.decode(&mut self.serial)?)
}
pub fn flush(&mut self) -> Result<(), Error> {
self.serial.flush()?;
Ok(())
}
}
pub trait LazyBytes<W: Write> {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error>;
fn length(&self) -> u16;
}
impl<W: Write> LazyBytes<W> for &[u8] {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error> {
encoder.write(self)?;
Ok(())
}
fn length(&self) -> u16 {
self.len() as u16
}
}
impl<W: Write, F: Fn(&mut SlipEncoder<W>) -> Result<(), Error>> LazyBytes<W> for (u16, F) {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error> {
self.1(encoder)
}
fn length(&self) -> u16 {
self.0
}
}

View File

@@ -1,15 +1,12 @@
use crate::chip::{Chip, ESP8266};
use crate::connection::Connection;
use crate::elf::FirmwareImage;
use crate::encoder::SlipEncoder;
use crate::error::RomError;
use crate::Error;
use bytemuck::{bytes_of, from_bytes, Pod, Zeroable};
use bytemuck::{bytes_of, Pod, Zeroable};
use serial::SerialPort;
use slip_codec::Decoder;
use std::io::Write;
use std::mem::size_of;
use std::thread::sleep;
use std::time::Duration;
type Encoder<'a> = SlipEncoder<'a, Box<dyn SerialPort>>;
@@ -41,18 +38,6 @@ enum Command {
ReadReg = 0x0a,
}
#[derive(Debug, Zeroable, Pod, Copy, Clone)]
#[repr(C)]
#[repr(packed)]
struct CommandResponse {
resp: u8,
return_op: u8,
return_length: u16,
value: u32,
status: u8,
error: u8,
}
#[derive(Zeroable, Pod, Copy, Clone, Debug)]
#[repr(C)]
struct BlockParams {
@@ -79,89 +64,16 @@ struct EntryParams {
}
pub struct Flasher {
serial: Box<dyn SerialPort>,
decoder: Decoder,
connected: bool,
connection: Connection,
}
impl Flasher {
pub fn new(serial: impl SerialPort + 'static) -> Self {
Flasher {
serial: Box::new(serial),
decoder: Decoder::new(1024),
connected: false,
}
}
fn reset_to_flash(&mut self) -> Result<(), Error> {
self.serial.set_dtr(false)?;
self.serial.set_rts(true)?;
sleep(Duration::from_millis(100));
self.serial.set_dtr(true)?;
self.serial.set_rts(false)?;
sleep(Duration::from_millis(50));
self.serial.set_dtr(true)?;
Ok(())
}
fn read_response(&mut self, timeout: Timeouts) -> Result<Option<CommandResponse>, Error> {
let response = self.read(timeout)?;
if response.len() < 10 {
return Ok(None);
}
let header: CommandResponse = *from_bytes(&response[0..10]);
Ok(Some(header))
}
fn write_command(
&mut self,
command: Command,
data: impl CommandData<Box<dyn SerialPort>>,
check: u32,
) -> Result<(), Error> {
let mut encoder = SlipEncoder::new(&mut self.serial)?;
encoder.write(&[0])?;
encoder.write(&[command as u8])?;
encoder.write(&(data.length().to_le_bytes()))?;
encoder.write(&(check.to_le_bytes()))?;
data.write(&mut encoder)?;
encoder.finish()?;
Ok(())
}
fn command<'a>(
&mut self,
command: Command,
data: impl CommandData<Box<dyn SerialPort>>,
check: u32,
timeout: Timeouts,
) -> Result<CommandResponse, Error> {
self.write_command(command, data, check)?;
match self.read_response(timeout)? {
Some(response) if response.return_op == command as u8 => {
if response.status == 1 {
Err(Error::RomError(RomError::from(response.error)))
} else {
Ok(response)
}
}
_ => Err(Error::ConnectionFailed),
}
}
fn read(&mut self, timeout: Timeouts) -> Result<Vec<u8>, Error> {
self.serial
.set_timeout(Duration::from_millis(timeout as u64))
.unwrap();
Ok(self.decoder.decode(&mut self.serial)?)
pub fn connect(serial: impl SerialPort + 'static) -> Result<Self, Error> {
let mut flasher = Flasher {
connection: Connection::new(serial),
};
flasher.start_connection()?;
Ok(flasher)
}
fn sync(&mut self) -> Result<(), Error> {
@@ -171,10 +83,11 @@ impl Flasher {
0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55,
][..];
self.write_command(Command::Sync, data, 0)?;
self.connection
.write_command(Command::Sync as u8, data, 0)?;
for _ in 0..10 {
match self.read_response(Timeouts::Sync)? {
match self.connection.read_response(Timeouts::Sync as u64)? {
Some(response) if response.return_op == Command::Sync as u8 => {
if response.status == 1 {
return Err(Error::RomError(RomError::from(response.error)));
@@ -188,7 +101,7 @@ impl Flasher {
for _ in 0..7 {
loop {
match self.read_response(Timeouts::Sync)? {
match self.connection.read_response(Timeouts::Sync as u64)? {
Some(_) => break,
_ => continue,
}
@@ -198,18 +111,14 @@ impl Flasher {
Ok(())
}
pub fn connect(&mut self) -> Result<(), Error> {
if self.connected {
return Ok(());
}
self.reset_to_flash()?;
fn start_connection(&mut self) -> Result<(), Error> {
self.connection.reset_to_flash()?;
for _ in 0..10 {
self.serial.flush()?;
self.connection.flush()?;
if let Ok(_) = self.sync() {
return Ok(());
}
}
self.connected = true;
Err(Error::ConnectionFailed)
}
@@ -227,7 +136,12 @@ impl Flasher {
block_size,
offset,
};
self.command(command, bytes_of(&params), 0, Timeouts::Default)?;
self.connection.command(
command as u8,
bytes_of(&params),
0,
Timeouts::Default as u64,
)?;
Ok(())
}
@@ -248,8 +162,8 @@ impl Flasher {
let length = size_of::<BlockParams>() + data.len() + padding;
self.command(
command,
self.connection.command(
command as u8,
(length as u16, |encoder: &mut Encoder| {
encoder.write(bytes_of(&params))?;
encoder.write(&data)?;
@@ -258,7 +172,7 @@ impl Flasher {
Ok(())
}),
checksum(&data, CHECKSUM_INIT) as u32,
Timeouts::Default,
Timeouts::Default as u64,
)?;
Ok(())
}
@@ -268,12 +182,14 @@ impl Flasher {
no_entry: (entry == 0) as u32,
entry,
};
self.write_command(Command::MemEnd, bytes_of(&params), 0)?;
self.connection
.write_command(Command::MemEnd as u8, bytes_of(&params), 0)?;
Ok(())
}
fn flash_finish(&mut self, reboot: bool) -> Result<(), Error> {
self.write_command(Command::FlashEnd, &[(!reboot) as u8][..], 0)?;
self.connection
.write_command(Command::FlashEnd as u8, &[(!reboot) as u8][..], 0)?;
Ok(())
}
@@ -287,7 +203,7 @@ impl Flasher {
///
/// Note that this will not touch the flash on the device
pub fn load_elf_to_ram(&mut self, elf_data: &[u8]) -> Result<(), Error> {
self.connect()?;
self.start_connection()?;
let image = FirmwareImage::from_data(elf_data).map_err(|_| Error::InvalidElf)?;
if image.rom_segments().next().is_some() {
@@ -319,7 +235,7 @@ impl Flasher {
/// Load an elf image to flash and execute it
pub fn load_elf_to_flash(&mut self, elf_data: &[u8]) -> Result<(), Error> {
self.connect()?;
self.start_connection()?;
self.enable_flash()?;
let image = FirmwareImage::from_data(elf_data).map_err(|_| Error::InvalidElf)?;
@@ -373,30 +289,3 @@ pub fn checksum(data: &[u8], mut checksum: u8) -> u8 {
checksum
}
trait CommandData<W: Write> {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error>;
fn length(&self) -> u16;
}
impl<W: Write> CommandData<W> for &[u8] {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error> {
encoder.write(self)?;
Ok(())
}
fn length(&self) -> u16 {
self.len() as u16
}
}
impl<W: Write, F: Fn(&mut SlipEncoder<W>) -> Result<(), Error>> CommandData<W> for (u16, F) {
fn write(self, encoder: &mut SlipEncoder<W>) -> Result<(), Error> {
self.1(encoder)
}
fn length(&self) -> u16 {
self.0
}
}

View File

@@ -1,4 +1,5 @@
mod chip;
mod connection;
mod elf;
mod encoder;
mod error;

View File

@@ -1,4 +1,5 @@
mod chip;
mod connection;
mod elf;
mod encoder;
mod error;
@@ -41,7 +42,7 @@ fn main() -> Result<(), MainError> {
Ok(())
})?;
let mut flasher = Flasher::new(serial);
let mut flasher = Flasher::connect(serial)?;
let input_bytes = read(&input)?;