use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::sync::Arc; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::postgres::message::{ self, Bind, Close, CommandComplete, DataRow, Flush, MessageFormat, ParameterDescription, Parse, Query, RowDescription, }; use crate::postgres::type_info::PgType; use crate::postgres::{PgArguments, PgConnection, PgDone, PgRow, PgValueFormat, Postgres}; use crate::statement::StatementInfo; async fn prepare( conn: &mut PgConnection, query: &str, arguments: &PgArguments, ) -> Result { let id = conn.next_statement_id; conn.next_statement_id = conn.next_statement_id.wrapping_add(1); // build a list of type OIDs to send to the database in the PARSE command // we have not yet started the query sequence, so we are *safe* to cleanly make // additional queries here to get any missing OIDs let mut param_types = Vec::with_capacity(arguments.types.len()); let mut has_fetched = false; for ty in &arguments.types { param_types.push(if let PgType::DeclareWithName(name) = &ty.0 { has_fetched = true; conn.fetch_type_id_by_name(name).await? } else { ty.0.oid() }); } // flush and wait until we are re-ready if has_fetched { conn.wait_until_ready().await?; } // next we send the PARSE command to the server conn.stream.write(Parse { param_types: &*param_types, query, statement: id, }); // we ask for the server to immediately send us the result of the PARSE command by using FLUSH conn.stream.write(Flush); conn.stream.flush().await?; // indicates that the SQL query string is now successfully parsed and has semantic validity let _: () = conn .stream .recv_expect(MessageFormat::ParseComplete) .await?; Ok(id) } async fn recv_desc_params(conn: &mut PgConnection) -> Result { conn.stream .recv_expect(MessageFormat::ParameterDescription) .await } async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { let rows: Option = match conn.stream.recv().await? { // describes the rows that will be returned when the statement is eventually executed message if message.format == MessageFormat::RowDescription => Some(message.decode()?), // no data would be returned if this statement was executed message if message.format == MessageFormat::NoData => None, message => { return Err(err_protocol!( "expecting RowDescription or NoData but received {:?}", message.format )); } }; Ok(rows) } impl PgConnection { // wait for CloseComplete to indicate a statement was closed pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> { // we need to wait for the [CloseComplete] to be returned from the server while count > 0 { match self.stream.recv().await? { message if message.format == MessageFormat::PortalSuspended => { // there was an open portal // this can happen if the last time a statement was used it was not fully executed // such as in [fetch_one] } message if message.format == MessageFormat::CloseComplete => { // successfully closed the statement (and freed up the server resources) count -= 1; } message => { return Err(err_protocol!( "expecting PortalSuspended or CloseComplete but received {:?}", message.format )); } } } Ok(()) } async fn prepare(&mut self, query: &str, arguments: &PgArguments) -> Result { if let Some(statement) = self.cache_statement.get_mut(query) { return Ok(*statement); } let statement = prepare(self, query, arguments).await?; if let Some(statement) = self.cache_statement.insert(query, statement) { self.stream.write(Close::Statement(statement)); self.stream.write(Flush); self.stream.flush().await?; self.wait_for_close_complete(1).await?; } Ok(statement) } async fn run( &mut self, query: &str, arguments: Option, limit: u8, ) -> Result, Error>> + '_, Error> { // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; let format = if let Some(mut arguments) = arguments { // prepare the statement if this our first time executing it // always return the statement ID here let statement = self.prepare(query, &arguments).await?; // patch holes created during encoding arguments.buffer.patch_type_holes(self).await?; // describe the statement and, again, ask the server to immediately respond // we need to fully realize the types self.stream.write(message::Describe::Statement(statement)); self.stream.write(message::Flush); self.stream.flush().await?; let _ = recv_desc_params(self).await?; let rows = recv_desc_rows(self).await?; self.handle_row_description(rows, true).await?; self.wait_until_ready().await?; // bind to attach the arguments to the statement and create a portal self.stream.write(Bind { portal: None, statement, formats: &[PgValueFormat::Binary], num_params: arguments.types.len() as i16, params: &*arguments.buffer, result_formats: &[PgValueFormat::Binary], }); // executes the portal up to the passed limit // the protocol-level limit acts nearly identically to the `LIMIT` in SQL self.stream.write(message::Execute { portal: None, limit: limit.into(), }); // finally, [Sync] asks postgres to process the messages that we sent and respond with // a [ReadyForQuery] message when it's completely done. Theoretically, we could send // dozens of queries before a [Sync] and postgres can handle that. Execution on the server // is still serial but it would reduce round-trips. Some kind of builder pattern that is // termed batching might suit this. self.stream.write(message::Sync); // prepared statements are binary PgValueFormat::Binary } else { self.stream.write(Query(query)); // and unprepared statements are text PgValueFormat::Text }; // [Query] or [Sync] will trigger a [ReadyForQuery] self.pending_ready_for_query_count += 1; self.stream.flush().await?; Ok(try_stream! { loop { let message = self.stream.recv().await?; match message.format { MessageFormat::BindComplete | MessageFormat::ParseComplete | MessageFormat::ParameterDescription | MessageFormat::NoData => { // harmless messages to ignore } MessageFormat::CommandComplete => { // a SQL command completed normally let cc: CommandComplete = message.decode()?; r#yield!(Either::Left(PgDone { rows_affected: cc.rows_affected(), })); } MessageFormat::EmptyQueryResponse => { // empty query string passed to an unprepared execute } MessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned self .handle_row_description(Some(message.decode()?), false) .await?; } MessageFormat::DataRow => { // one of the set of rows returned by a SELECT, FETCH, etc query let data: DataRow = message.decode()?; let row = PgRow { data, format, columns: Arc::clone(&self.scratch_row_columns), column_names: Arc::clone(&self.scratch_row_column_names), }; r#yield!(Either::Right(row)); } MessageFormat::ReadyForQuery => { // processing of the query string is complete self.handle_ready_for_query(message)?; break; } _ => { Err(err_protocol!( "execute: unexpected message: {:?}", message.format ))?; } } } Ok(()) }) } } impl<'c> Executor<'c> for &'c mut PgConnection { type Database = Postgres; fn fetch_many<'e, 'q: 'e, E: 'q>( self, mut query: E, ) -> BoxStream<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { let s = query.query(); let arguments = query.take_arguments(); Box::pin(try_stream! { let s = self.run(s, arguments, 0).await?; pin_mut!(s); while let Some(v) = s.try_next().await? { r#yield!(v); } Ok(()) }) } fn fetch_optional<'e, 'q: 'e, E: 'q>( self, mut query: E, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { let s = query.query(); let arguments = query.take_arguments(); Box::pin(async move { let s = self.run(s, arguments, 1).await?; pin_mut!(s); while let Some(s) = s.try_next().await? { if let Either::Right(r) = s { return Ok(Some(r)); } } Ok(None) }) } #[doc(hidden)] fn describe<'e, 'q: 'e, E: 'q>( self, query: E, ) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, { let s = query.query(); Box::pin(async move { let id = prepare(self, s, &Default::default()).await?; self.stream.write(message::Describe::Statement(id)); self.stream.write(Flush); self.stream.flush().await?; let params = recv_desc_params(self).await?; let rows = recv_desc_rows(self).await?; let params = self.handle_parameter_description(params).await?; self.handle_row_description(rows, true).await?; let columns = (&*self.scratch_row_columns).clone(); let nullable = self.get_nullable_for_columns(&columns).await?; Ok(StatementInfo { columns, nullable, parameters: Some(Either::Left(params)), }) }) } }