From d3460fd4795ef88041f6ef11fcdabe2a7f645709 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Tue, 16 Jul 2019 12:51:54 -0700 Subject: [PATCH] Fix command complete row parsing for row-less commands --- .../src/command_complete.rs | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/sqlx-postgres-protocol/src/command_complete.rs b/sqlx-postgres-protocol/src/command_complete.rs index 11d2d352..b750e0cd 100644 --- a/sqlx-postgres-protocol/src/command_complete.rs +++ b/sqlx-postgres-protocol/src/command_complete.rs @@ -1,6 +1,6 @@ use crate::Decode; use bytes::Bytes; -use memchr::{memchr, memrchr}; +use memchr::memrchr; use std::{io, str}; #[derive(Debug)] @@ -9,14 +9,20 @@ pub struct CommandComplete { } impl CommandComplete { + #[inline] pub fn tag(&self) -> &str { - unsafe { str::from_utf8_unchecked(self.tag.as_ref()) } + unsafe { str::from_utf8_unchecked(&self.tag.as_ref()[..self.tag.len() - 1]) } } pub fn rows(&self) -> u64 { - let rows_start = memrchr(b' ', &*self.tag).map_or(0, |i| i + 1); - let rows_s = - unsafe { str::from_utf8_unchecked(&self.tag[rows_start..(self.tag.len() - 1)]) }; + // Attempt to parse the last word in the command tag as an integer + // If it can't be parased, the tag is probably "CREATE TABLE" or something + // and we should return 0 rows + + let rows_start = memrchr(b' ', &*self.tag).unwrap(); + let rows_s = unsafe { + str::from_utf8_unchecked(&self.tag.as_ref()[(rows_start + 1)..(self.tag.len() - 1)]) + }; rows_s.parse().unwrap_or(0) } @@ -35,7 +41,8 @@ mod tests { use bytes::Bytes; use std::io; - const COMMAND_COMPLETE_INSERT: &[u8] = b"INSERT 0 512\0"; + const COMMAND_COMPLETE_INSERT: &[u8] = b"INSERT 0 1\0"; + const COMMAND_COMPLETE_UPDATE: &[u8] = b"UPDATE 512\0"; const COMMAND_COMPLETE_CREATE_TABLE: &[u8] = b"CREATE TABLE\0"; #[test] @@ -43,7 +50,18 @@ mod tests { let src = Bytes::from_static(COMMAND_COMPLETE_INSERT); let message = CommandComplete::decode(src)?; - assert_eq!(message.tag(), "INSERT 0 512"); + assert_eq!(message.tag(), "INSERT 0 1"); + assert_eq!(message.rows(), 1); + + Ok(()) + } + + #[test] + fn it_decodes_command_complete_for_update() -> io::Result<()> { + let src = Bytes::from_static(COMMAND_COMPLETE_UPDATE); + let message = CommandComplete::decode(src)?; + + assert_eq!(message.tag(), "UPDATE 512"); assert_eq!(message.rows(), 512); Ok(()) @@ -51,7 +69,7 @@ mod tests { #[test] fn it_decodes_command_complete_for_create_table() -> io::Result<()> { - let src = Bytes::from_static(COMMAND_COMPLETE_INSERT); + let src = Bytes::from_static(COMMAND_COMPLETE_CREATE_TABLE); let message = CommandComplete::decode(src)?; assert_eq!(message.tag(), "CREATE TABLE");