use std::collections::HashMap; use std::sync::Arc; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use crate::describe::{Column, Describe}; use crate::executor::Executor; use crate::mysql::error::MySqlError; use crate::mysql::protocol::{ Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComSetOption, ComStmtExecute, ComStmtPrepare, ComStmtPrepareOk, Cursor, Decode, EofPacket, ErrPacket, OkPacket, Row, SetOption, Type, }; use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow}; enum Step { Command(u64), Row(Row), } enum OkOrResultSet { Ok(OkPacket), ResultSet(ColumnCount), } impl MySqlConnection { async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> { for _ in 0..count { let _column = ColumnDefinition::decode(self.receive().await?)?; } if count > 0 { self.receive_eof().await?; } Ok(()) } async fn receive_ok_or_column_count(&mut self) -> crate::Result { let packet = self.receive().await?; match packet[0] { 0xfe if packet.len() < 0xffffff => { let ok = OkPacket::decode(packet)?; self.ready = true; Ok(OkOrResultSet::Ok(ok)) } 0x00 => { let ok = OkPacket::decode(packet)?; self.ready = true; Ok(OkOrResultSet::Ok(ok)) } 0xff => { let err = ErrPacket::decode(packet)?; self.ready = true; Err(MySqlError(err).into()) } _ => { let cc = ColumnCount::decode(packet)?; Ok(OkOrResultSet::ResultSet(cc)) } } } async fn receive_column_types(&mut self, count: usize) -> crate::Result> { let mut columns: Vec = Vec::with_capacity(count); for _ in 0..count { let packet = self.receive().await?; let column: ColumnDefinition = ColumnDefinition::decode(packet)?; columns.push(column.r#type); } if count > 0 { self.receive_eof().await?; } Ok(columns.into_boxed_slice()) } async fn wait_for_ready(&mut self) -> crate::Result<()> { if !self.ready { while let Some(_step) = self.step(&[], true).await? { // Drain steps until we hit the end } } Ok(()) } async fn prepare(&mut self, query: &str) -> crate::Result { // Start by sending a COM_STMT_PREPARE self.begin_command_phase(); self.write(ComStmtPrepare { query }); self.stream.flush().await?; // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html // First we should receive a COM_STMT_PREPARE_OK let packet = self.receive().await?; if packet[0] == 0xff { // Oops, there was an error in the prepare command return Err(MySqlError(ErrPacket::decode(packet)?).into()); } ComStmtPrepareOk::decode(packet) } async fn prepare_with_cache(&mut self, query: &str) -> crate::Result { if let Some(&id) = self.statement_cache.get(query) { Ok(id) } else { let prepare_ok = self.prepare(query).await?; // Remember our statement ID, so we do'd do this again the next time self.statement_cache .put(query.to_owned(), prepare_ok.statement_id); // Ignore input parameters self.ignore_columns(prepare_ok.params as usize).await?; // Collect output parameter names let mut columns = HashMap::with_capacity(prepare_ok.columns as usize); let mut index = 0_usize; for _ in 0..prepare_ok.columns { let column = ColumnDefinition::decode(self.receive().await?)?; if let Some(name) = column.column_alias.or(column.column) { columns.insert(name, index); } index += 1; } if prepare_ok.columns > 0 { self.receive_eof().await?; } // Remember our column map in the statement cache self.statement_cache .put_columns(prepare_ok.statement_id, columns); Ok(prepare_ok.statement_id) } } // [COM_STMT_EXECUTE] async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> { self.begin_command_phase(); self.ready = false; self.write(ComStmtExecute { cursor: Cursor::NO_CURSOR, statement_id: id, params: &args.params, null_bitmap: &args.null_bitmap, param_types: &args.param_types, }); self.stream.flush().await?; Ok(()) } async fn step(&mut self, columns: &[Type], binary: bool) -> crate::Result> { let capabilities = self.capabilities; let packet = ret_if_none!(self.try_receive().await?); match packet[0] { 0xfe if packet.len() < 0xffffff => { // Resultset row can begin with 0xfe byte (when using text protocol // with a field length > 0xffffff) if !capabilities.contains(Capabilities::DEPRECATE_EOF) { let _eof = EofPacket::decode(packet)?; self.ready = true; return Ok(None); } else { let ok = OkPacket::decode(packet)?; self.ready = true; return Ok(Some(Step::Command(ok.affected_rows))); } } 0xff => { let err = ErrPacket::decode(packet)?; self.ready = true; return Err(MySqlError(err).into()); } _ => { return Ok(Some(Step::Row(Row::decode(packet, columns, binary)?))); } } } } impl MySqlConnection { async fn send(&mut self, query: &str) -> crate::Result<()> { self.wait_for_ready().await?; self.begin_command_phase(); self.ready = false; // enable multi-statement only for this query self.write(ComQuery { query }); self.stream.flush().await?; // COM_QUERY can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(_) => { return Ok(()); } OkOrResultSet::ResultSet(cc) => cc.columns as usize, }; let columns = self.receive_column_types(num_columns as usize).await?; while let Some(step) = self.step(&columns, false).await? { // Drop all responses } Ok(()) } async fn execute(&mut self, query: &str, args: MySqlArguments) -> crate::Result { self.wait_for_ready().await?; let statement_id = self.prepare_with_cache(query).await?; self.execute_statement(statement_id, args).await?; // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(ok) => { return Ok(ok.affected_rows); } OkOrResultSet::ResultSet(cc) => cc.columns as usize, }; self.ignore_columns(num_columns).await?; let mut res = 0; while let Some(step) = self.step(&[], true).await? { if let Step::Command(affected) = step { res = affected; } } Ok(res) } async fn describe(&mut self, query: &str) -> crate::Result> { self.wait_for_ready().await?; let prepare_ok = self.prepare(query).await?; let mut param_types = Vec::with_capacity(prepare_ok.params as usize); let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize); for _ in 0..prepare_ok.params { let param = ColumnDefinition::decode(self.receive().await?)?; param_types.push(param.r#type.0); } if prepare_ok.params > 0 { self.receive_eof().await?; } for _ in 0..prepare_ok.columns { let column = ColumnDefinition::decode(self.receive().await?)?; result_columns.push(Column:: { name: column.column_alias.or(column.column), table_id: column.table_alias.or(column.table), type_id: column.r#type.0, _non_exhaustive: (), }); } if prepare_ok.columns > 0 { self.receive_eof().await?; } Ok(Describe { param_types: param_types.into_boxed_slice(), result_columns: result_columns.into_boxed_slice(), _non_exhaustive: (), }) } fn fetch<'e, 'q: 'e>( &'e mut self, query: &'q str, args: MySqlArguments, ) -> BoxStream<'e, crate::Result> { Box::pin(async_stream::try_stream! { self.wait_for_ready().await?; let statement_id = self.prepare_with_cache(query).await?; let columns = self.statement_cache.get_columns(statement_id); self.execute_statement(statement_id, args).await?; // COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet let num_columns = match self.receive_ok_or_column_count().await? { OkOrResultSet::Ok(_) => { return; } OkOrResultSet::ResultSet(cc) => { cc.columns as usize } }; let column_types = self.receive_column_types(num_columns).await?; while let Some(Step::Row(row)) = self.step(&column_types, true).await? { yield MySqlRow { row, columns: Arc::clone(&columns) }; } }) } } impl Executor for MySqlConnection { type Database = super::MySql; fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> { Box::pin(self.send(query)) } fn execute<'e, 'q: 'e>( &'e mut self, query: &'q str, args: MySqlArguments, ) -> BoxFuture<'e, crate::Result> { Box::pin(self.execute(query, args)) } fn fetch<'e, 'q: 'e>( &'e mut self, query: &'q str, args: MySqlArguments, ) -> BoxStream<'e, crate::Result> { self.fetch(query, args) } fn describe<'e, 'q: 'e>( &'e mut self, query: &'q str, ) -> BoxFuture<'e, crate::Result>> { Box::pin(self.describe(query)) } }