feat(ok): add correct handling of ok packets in MYSQL implementation (#3910)

* feat(ok): add correct handling of ok packet

* feat(ok): add unit tests
This commit is contained in:
JP 2025-06-30 20:31:55 -04:00 committed by GitHub
parent 69f9ff9180
commit 8ff14dc37c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 40 additions and 9 deletions

View File

@ -208,22 +208,27 @@ impl MySqlConnection {
loop {
let packet = self.inner.stream.recv_packet().await?;
if packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.inner.stream.capabilities)?;
self.inner.status_flags = eof.status;
if packet[0] == 0xfe {
let (rows_affected, last_insert_id, status) = if packet.len() < 9 {
// EOF packet
let eof = packet.eof(self.inner.stream.capabilities)?;
(0, 0, eof.status)
} else {
// OK packet
let ok = packet.ok()?;
(ok.affected_rows, ok.last_insert_id, ok.status)
};
self.inner.status_flags = status;
r#yield!(Either::Left(MySqlQueryResult {
rows_affected: 0,
last_insert_id: 0,
rows_affected,
last_insert_id,
}));
if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// more result sets exist, continue to the next one
if status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
*self.inner.stream.waiting.front_mut().unwrap() = Waiting::Result;
break;
}
self.inner.stream.waiting.pop_front();
return Ok(());
}

View File

@ -50,3 +50,29 @@ fn test_decode_ok_packet() {
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert!(p.status.contains(Status::SERVER_SESSION_STATE_CHANGED));
}
#[test]
fn test_decode_ok_packet_with_info() {
// OK packet with 0xfe header and length >= 9 (with appended info)
const DATA: &[u8] = b"\xfe\x01\x00\x02\x00\x00\x00\x05\x09info data";
let p = OkPacket::decode(DATA.into()).unwrap();
assert_eq!(p.affected_rows, 1);
assert_eq!(p.last_insert_id, 0);
assert_eq!(p.warnings, 0);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
}
#[test]
fn test_decode_ok_packet_with_extended_info() {
// OK packet with 0xfe header, affected rows, last insert id, and extended info
const DATA: &[u8] = b"\xfe\x05\x64\x02\x00\x01\x00\x0e\x14extended information";
let p = OkPacket::decode(DATA.into()).unwrap();
assert_eq!(p.affected_rows, 5);
assert_eq!(p.last_insert_id, 100);
assert_eq!(p.warnings, 1);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
}