use std::collections::{HashMap, HashSet};
use std::fmt::Write;

use futures_core::future::BoxFuture;
use futures_util::{stream, StreamExt, TryStreamExt};

use crate::arguments::Arguments;
use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::postgres::protocol::{
    self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription,
    StatementId, TypeFormat, TypeId,
};
use crate::postgres::type_info::SharedStr;
use crate::postgres::types::try_resolve_type_name;
use crate::postgres::{
    PgArguments, PgConnection, PgCursor, PgQueryAs, PgRow, PgTypeInfo, Postgres,
};
use crate::query_as::query_as;
use crate::row::Row;

impl PgConnection {
    pub(crate) fn write_simple_query(&mut self, query: &str) {
        self.stream.write(protocol::Query(query));
    }

    pub(crate) async fn write_prepare(
        &mut self,
        query: &str,
        args: &PgArguments,
    ) -> crate::Result<StatementId> {
        if let Some(&id) = self.cache_statement_id.get(query) {
            Ok(id)
        } else {
            let id = StatementId(self.next_statement_id);
            self.next_statement_id += 1;

            // Build a list of type OIDs from the type info array provided by PgArguments.
            // This may need to query Postgres for the OID of a user-defined type.
            let mut types = Vec::with_capacity(args.types.len());

            for ty in &args.types {
                types.push(if let Some(oid) = ty.id {
                    oid.0
                } else {
                    self.get_type_id_by_name(&*ty.name).await?
                });
            }

            self.stream.write(protocol::Parse {
                statement: id,
                param_types: &*types,
                query,
            });

            Ok(id)
        }
    }

    pub(crate) fn write_describe(&mut self, d: protocol::Describe) {
        self.stream.write(d);
    }

    pub(crate) async fn write_bind(
        &mut self,
        portal: &str,
        statement: StatementId,
        args: &mut PgArguments,
    ) -> crate::Result<()> {
        args.buffer.patch_type_holes(self).await?;

        self.stream.write(protocol::Bind {
            portal,
            statement,
            formats: &[TypeFormat::Binary],
            values_len: args.types.len() as i16,
            values: &*args.buffer,
            result_formats: &[TypeFormat::Binary],
        });

        Ok(())
    }

    pub(crate) fn write_execute(&mut self, portal: &str, limit: i32) {
        self.stream.write(protocol::Execute { portal, limit });
    }

    pub(crate) fn write_sync(&mut self) {
        self.stream.write(protocol::Sync);
    }

    async fn wait_until_ready(&mut self) -> crate::Result<()> {
        // Depending on how the previous query finished, we may need to continue
        // pulling messages from the stream until we receive a [ReadyForQuery] message.
        // Postgres sends [ReadyForQuery] once it is fully done processing the
        // previous query.
        if !self.is_ready {
            loop {
                if let Message::ReadyForQuery = self.stream.receive().await? {
                    // we are now ready to go
                    self.is_ready = true;
                    break;
                }
            }
        }

        Ok(())
    }

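    // For reference, the prepared-query round trip written by [PgConnection::run] below
    // looks roughly like this on the wire (server responses in parentheses):
    //
    //     Parse       (ParseComplete)
    //     Bind        (BindComplete)
    //     Describe    (RowDescription | NoData)
    //     Execute     (DataRow* .. CommandComplete)
    //     Sync        (ReadyForQuery)
    //
    // https://www.postgresql.org/docs/12/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
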
    // Write out the query to the connection stream, ensure that we are synchronized at the
    // most recent [ReadyForQuery], and flush our buffer to postgres.
    //
    // It is safe to call this method repeatedly (though any pending data from postgres
    // would be lost); it is assumed that a call to [PgConnection::affected_rows] or
    // [PgCursor::next] will immediately follow.
    pub(crate) async fn run(
        &mut self,
        query: &str,
        arguments: Option<PgArguments>,
    ) -> crate::Result<Option<StatementId>> {
        let statement = if let Some(mut arguments) = arguments {
            // Check the statement cache for a statement ID that matches the given query.
            // If it doesn't exist, we generate a new statement ID and write out [Parse] to the
            // connection command buffer.
            let statement = self.write_prepare(query, &arguments).await?;

            // Next, [Bind] attaches the arguments to the statement and creates the
            // unnamed portal.
            self.write_bind("", statement, &mut arguments).await?;

            // Next, [Describe] will return the expected result columns and types.
            // Conditionally run [Describe] only if the results have not been cached.
            if !self.cache_statement.contains_key(&statement) {
                self.write_describe(protocol::Describe::Portal(""));
            }

            // Next, [Execute] executes the unnamed portal.
            self.write_execute("", 0);

            // Finally, [Sync] asks postgres to process the messages that we sent and respond
            // with a [ReadyForQuery] message when it's completely done. Theoretically, we
            // could send dozens of queries before a [Sync] and postgres can handle that.
            // Execution on the server is still serial but it would reduce round-trips.
            // Some kind of builder pattern termed batching might suit this; a sketch
            // follows [do_describe] below.
            self.write_sync();

            Some(statement)
        } else {
            // https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.4
            self.write_simple_query(query);

            None
        };

        self.wait_until_ready().await?;
        self.stream.flush().await?;
        self.is_ready = false;

        // only cache the statement ID if this was a prepared query
        if let Some(statement) = statement {
            // prefer a redundant lookup to copying the query string
            if !self.cache_statement_id.contains_key(query) {
                // wait for `ParseComplete` on the stream (or an error)
                // before we cache the statement
                match self.stream.receive().await? {
                    Message::ParseComplete => {
                        self.cache_statement_id.insert(query.into(), statement);
                    }

                    message => {
                        return Err(protocol_err!("run: unexpected message: {:?}", message).into());
                    }
                }
            }
        }

        Ok(statement)
    }

    async fn do_describe<'e, 'q: 'e>(
        &'e mut self,
        query: &'q str,
    ) -> crate::Result<Describe<Postgres>> {
        self.is_ready = false;

        let statement = self.write_prepare(query, &Default::default()).await?;

        self.write_describe(protocol::Describe::Statement(statement));
        self.write_sync();

        self.stream.flush().await?;

        let params = loop {
            match self.stream.receive().await? {
                Message::ParseComplete => {}

                Message::ParameterDescription => {
                    break ParameterDescription::read(self.stream.buffer())?;
                }

                message => {
                    return Err(protocol_err!(
                        "expected ParameterDescription; received {:?}",
                        message
                    )
                    .into());
                }
            }
        };

        let result = match self.stream.receive().await? {
            Message::NoData => None,
            Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),

            message => {
                return Err(protocol_err!(
                    "expected RowDescription or NoData; received {:?}",
                    message
                )
                .into());
            }
        };

        self.wait_until_ready().await?;

        let result_fields = result.map_or_else(Default::default, |r| r.fields);

        let type_names = self
            .get_type_names(
                params
                    .ids
                    .iter()
                    .cloned()
                    .chain(result_fields.iter().map(|field| field.type_id)),
            )
            .await?;

        Ok(Describe {
            param_types: params
                .ids
                .iter()
                .map(|id| Some(PgTypeInfo::new(*id, &type_names[&id.0])))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            result_columns: self
                .map_result_columns(result_fields, type_names)
                .await?
                .into_boxed_slice(),
        })
    }

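    // A rough sketch of the batching idea mentioned in [PgConnection::run]:
    // queue several prepared executions and close the pipeline with a single
    // [Sync]. `run_batch` is hypothetical and not used anywhere in this module;
    // it skips the statement-cache bookkeeping that [run] performs after
    // [ParseComplete], and the caller would still need to drain one
    // [CommandComplete] per query before the final [ReadyForQuery].
    #[allow(dead_code)]
    async fn run_batch(&mut self, queries: &mut [(&str, PgArguments)]) -> crate::Result<()> {
        for (query, arguments) in queries.iter_mut() {
            let statement = self.write_prepare(*query, arguments).await?;

            self.write_bind("", statement, arguments).await?;
            self.write_execute("", 0);
        }

        // A single [Sync] for the whole pipeline; execution on the server is
        // still serial but we only pay for one round-trip.
        self.write_sync();

        self.wait_until_ready().await?;
        self.stream.flush().await?;
        self.is_ready = false;

        Ok(())
    }
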
    pub(crate) async fn get_type_id_by_name(&mut self, name: &str) -> crate::Result<u32> {
        if let Some(oid) = self.cache_type_oid.get(name) {
            return Ok(*oid);
        }

        // language=SQL
        let (oid,): (u32,) = query_as(
            "
SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
            ",
        )
        .bind(name)
        .fetch_one(&mut *self)
        .await?;

        let shared = SharedStr::from(name.to_owned());

        self.cache_type_oid.insert(shared.clone(), oid);
        self.cache_type_name.insert(oid, shared);

        Ok(oid)
    }

    pub(crate) fn get_type_info_by_oid(&mut self, oid: u32) -> PgTypeInfo {
        if let Some(name) = try_resolve_type_name(oid) {
            return PgTypeInfo::new(TypeId(oid), name);
        }

        if let Some(name) = self.cache_type_name.get(&oid) {
            return PgTypeInfo::new(TypeId(oid), name);
        }

        // NOTE: The name isn't too important for the decode lifecycle
        PgTypeInfo::new(TypeId(oid), "")
    }

    async fn get_type_names(
        &mut self,
        ids: impl IntoIterator<Item = TypeId>,
    ) -> crate::Result<HashMap<u32, SharedStr>> {
        let type_ids: HashSet<u32> = ids.into_iter().map(|id| id.0).collect();

        if type_ids.is_empty() {
            return Ok(HashMap::new());
        }

        // uppercase type names are easier to visually identify
        let mut query = "select types.type_id, UPPER(pg_type.typname) from (VALUES ".to_string();
        let mut args = PgArguments::default();
        let mut pushed = false;

        // TODO: dedup this with the one below, ideally as an API we can export
        for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() {
            if pushed {
                query += ", ";
            }

            pushed = true;
            let _ = write!(query, "(${}, ${})", bind, bind + 1);

            // the index is not used in the output but ensures our values are
            // sorted correctly
            args.add(i as i32);
            args.add(type_id as i32);
        }

        query += ") as types(idx, type_id) \
                  inner join pg_catalog.pg_type on pg_type.oid = type_id \
                  order by types.idx";

        crate::query::query(&query)
            .bind_all(args)
            .try_map(|row: PgRow| -> crate::Result<(u32, SharedStr)> {
                Ok((
                    row.try_get::<i32, _>(0)? as u32,
                    row.try_get::<String, _>(1)?.into(),
                ))
            })
            .fetch(self)
            .try_collect()
            .await
    }

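    // For example, with two result columns the function below builds (modulo
    // whitespace) the following correlation query, with one
    // (idx, table_id, col_idx) triple bound per column:
    //
    //     select col.idx, pg_attribute.attnotnull
    //     from (VALUES ($1::int4, $2::int4, $3::int2),
    //                  ($4::int4, $5::int4, $6::int2)) as col(idx, table_id, col_idx)
    //     left join pg_catalog.pg_attribute
    //         on table_id is not null and attrelid = table_id and attnum = col_idx
    //     order by col.idx;
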
    async fn map_result_columns(
        &mut self,
        fields: Box<[Field]>,
        type_names: HashMap<u32, SharedStr>,
    ) -> crate::Result<Vec<Column<Postgres>>> {
        if fields.is_empty() {
            return Ok(vec![]);
        }

        let mut query = "select col.idx, pg_attribute.attnotnull from (VALUES ".to_string();
        let mut pushed = false;
        let mut args = PgArguments::default();

        for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() {
            if pushed {
                query += ", ";
            }

            pushed = true;
            let _ = write!(
                query,
                "(${}::int4, ${}::int4, ${}::int2)",
                bind,
                bind + 1,
                bind + 2
            );

            args.add(i as i32);
            args.add(field.table_id.map(|id| id as i32));
            args.add(field.column_id);
        }

        query += ") as col(idx, table_id, col_idx) \
                  left join pg_catalog.pg_attribute on table_id is not null and attrelid = table_id and attnum = col_idx \
                  order by col.idx;";

        log::trace!("describe pg_attribute query: {:#?}", query);

        crate::query::query(&query)
            .bind_all(args)
            .try_map(|row: PgRow| {
                let idx = row.try_get::<i32, _>(0)?;
                let non_null = row.try_get::<Option<bool>, _>(1)?;

                Ok((idx, non_null))
            })
            .fetch(self)
            .zip(stream::iter(fields.into_vec().into_iter().enumerate()))
            .map(|(row, (fidx, field))| -> crate::Result<Column<Postgres>> {
                let (idx, non_null) = row?;

                if idx != fidx as i32 {
                    return Err(
                        protocol_err!("missing field from query, field: {:?}", field).into(),
                    );
                }

                Ok(Column {
                    name: field.name,
                    table_id: field.table_id,
                    type_info: Some(PgTypeInfo::new(
                        field.type_id,
                        &type_names[&field.type_id.0],
                    )),
                    non_null,
                })
            })
            .try_collect()
            .await
    }

    // Poll messages from Postgres, counting the rows affected, until we finish the query.
    // This must be called directly after a call to [PgConnection::execute].
    async fn affected_rows(&mut self) -> crate::Result<u64> {
        let mut rows = 0;

        loop {
            match self.stream.receive().await? {
                Message::ParseComplete
                | Message::BindComplete
                | Message::NoData
                | Message::EmptyQueryResponse
                | Message::RowDescription => {}

                Message::DataRow => {
                    // TODO: should we log a warning? this is almost
                    //       definitely a programmer error
                }

                Message::CommandComplete => {
                    rows += CommandComplete::read(self.stream.buffer())?.affected_rows;
                }

                Message::ReadyForQuery => {
                    // TODO: How should we handle an ERROR status from ReadyForQuery
                    let _ready = ReadyForQuery::read(self.stream.buffer())?;

                    self.is_ready = true;
                    break;
                }

                message => {
                    return Err(
                        protocol_err!("affected_rows: unexpected message: {:?}", message).into(),
                    );
                }
            }
        }

        Ok(rows)
    }
}

impl Executor for super::PgConnection {
    type Database = Postgres;

    fn execute<'e, 'q: 'e, 'c: 'e, E: 'e>(
        &'c mut self,
        query: E,
    ) -> BoxFuture<'e, crate::Result<u64>>
    where
        E: Execute<'q, Self::Database>,
    {
        Box::pin(async move {
            let (query, arguments) = query.into_parts();

            self.run(query, arguments).await?;
            self.affected_rows().await
        })
    }

    fn fetch<'q, E>(&mut self, query: E) -> PgCursor<'_, 'q>
    where
        E: Execute<'q, Self::Database>,
    {
        PgCursor::from_connection(self, query)
    }

    #[doc(hidden)]
    fn describe<'e, 'q, E: 'e>(
        &'e mut self,
        query: E,
    ) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>>
    where
        E: Execute<'q, Self::Database>,
    {
        Box::pin(async move { self.do_describe(query.into_parts().0).await })
    }
}

impl<'c> RefExecutor<'c> for &'c mut super::PgConnection {
    type Database = Postgres;

    fn fetch_by_ref<'q, E>(self, query: E) -> PgCursor<'c, 'q>
    where
        E: Execute<'q, Self::Database>,
    {
        PgCursor::from_connection(self, query)
    }
}
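
// A usage sketch of the `Executor` impl above (hypothetical; `conn` is assumed
// to be an established `PgConnection` and the table is illustrative):
//
//     let rows_affected = conn
//         .execute(sqlx::query("DELETE FROM users WHERE id = $1").bind(1_i32))
//         .await?;
//
// `execute` drives the query through [PgConnection::run] and then tallies
// [CommandComplete] frames via [PgConnection::affected_rows].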