mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-10 11:14:45 +00:00
359 lines
12 KiB
Rust
359 lines
12 KiB
Rust
use either::Either;
|
|
use futures_core::future::BoxFuture;
|
|
use futures_core::stream::BoxStream;
|
|
use futures_core::Stream;
|
|
use futures_util::{pin_mut, TryStreamExt};
|
|
use std::sync::Arc;
|
|
|
|
use crate::error::Error;
|
|
use crate::executor::{Execute, Executor};
|
|
use crate::postgres::message::{
|
|
self, Bind, Close, CommandComplete, DataRow, Flush, MessageFormat, ParameterDescription, Parse,
|
|
Query, RowDescription,
|
|
};
|
|
use crate::postgres::type_info::PgType;
|
|
use crate::postgres::{PgArguments, PgConnection, PgDone, PgRow, PgValueFormat, Postgres};
|
|
use crate::statement::StatementInfo;
|
|
|
|
async fn prepare(
|
|
conn: &mut PgConnection,
|
|
query: &str,
|
|
arguments: &PgArguments,
|
|
) -> Result<u32, Error> {
|
|
let id = conn.next_statement_id;
|
|
conn.next_statement_id = conn.next_statement_id.wrapping_add(1);
|
|
|
|
// build a list of type OIDs to send to the database in the PARSE command
|
|
// we have not yet started the query sequence, so we are *safe* to cleanly make
|
|
// additional queries here to get any missing OIDs
|
|
|
|
let mut param_types = Vec::with_capacity(arguments.types.len());
|
|
let mut has_fetched = false;
|
|
|
|
for ty in &arguments.types {
|
|
param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
|
|
has_fetched = true;
|
|
conn.fetch_type_id_by_name(name).await?
|
|
} else {
|
|
ty.0.oid()
|
|
});
|
|
}
|
|
|
|
// flush and wait until we are re-ready
|
|
if has_fetched {
|
|
conn.wait_until_ready().await?;
|
|
}
|
|
|
|
// next we send the PARSE command to the server
|
|
conn.stream.write(Parse {
|
|
param_types: &*param_types,
|
|
query,
|
|
statement: id,
|
|
});
|
|
|
|
// we ask for the server to immediately send us the result of the PARSE command by using FLUSH
|
|
conn.stream.write(Flush);
|
|
conn.stream.flush().await?;
|
|
|
|
// indicates that the SQL query string is now successfully parsed and has semantic validity
|
|
let _: () = conn
|
|
.stream
|
|
.recv_expect(MessageFormat::ParseComplete)
|
|
.await?;
|
|
|
|
Ok(id)
|
|
}
|
|
|
|
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
|
|
conn.stream
|
|
.recv_expect(MessageFormat::ParameterDescription)
|
|
.await
|
|
}
|
|
|
|
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
|
|
let rows: Option<RowDescription> = match conn.stream.recv().await? {
|
|
// describes the rows that will be returned when the statement is eventually executed
|
|
message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
|
|
|
|
// no data would be returned if this statement was executed
|
|
message if message.format == MessageFormat::NoData => None,
|
|
|
|
message => {
|
|
return Err(err_protocol!(
|
|
"expecting RowDescription or NoData but received {:?}",
|
|
message.format
|
|
));
|
|
}
|
|
};
|
|
|
|
Ok(rows)
|
|
}
|
|
|
|
impl PgConnection {
|
|
// wait for CloseComplete to indicate a statement was closed
|
|
pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
|
|
// we need to wait for the [CloseComplete] to be returned from the server
|
|
while count > 0 {
|
|
match self.stream.recv().await? {
|
|
message if message.format == MessageFormat::PortalSuspended => {
|
|
// there was an open portal
|
|
// this can happen if the last time a statement was used it was not fully executed
|
|
// such as in [fetch_one]
|
|
}
|
|
|
|
message if message.format == MessageFormat::CloseComplete => {
|
|
// successfully closed the statement (and freed up the server resources)
|
|
count -= 1;
|
|
}
|
|
|
|
message => {
|
|
return Err(err_protocol!(
|
|
"expecting PortalSuspended or CloseComplete but received {:?}",
|
|
message.format
|
|
));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn prepare(&mut self, query: &str, arguments: &PgArguments) -> Result<u32, Error> {
|
|
if let Some(statement) = self.cache_statement.get_mut(query) {
|
|
return Ok(*statement);
|
|
}
|
|
|
|
let statement = prepare(self, query, arguments).await?;
|
|
|
|
if let Some(statement) = self.cache_statement.insert(query, statement) {
|
|
self.stream.write(Close::Statement(statement));
|
|
self.stream.write(Flush);
|
|
|
|
self.stream.flush().await?;
|
|
|
|
self.wait_for_close_complete(1).await?;
|
|
}
|
|
|
|
Ok(statement)
|
|
}
|
|
|
|
async fn run(
|
|
&mut self,
|
|
query: &str,
|
|
arguments: Option<PgArguments>,
|
|
limit: u8,
|
|
) -> Result<impl Stream<Item = Result<Either<PgDone, PgRow>, Error>> + '_, Error> {
|
|
// before we continue, wait until we are "ready" to accept more queries
|
|
self.wait_until_ready().await?;
|
|
|
|
let format = if let Some(mut arguments) = arguments {
|
|
// prepare the statement if this our first time executing it
|
|
// always return the statement ID here
|
|
let statement = self.prepare(query, &arguments).await?;
|
|
|
|
// patch holes created during encoding
|
|
arguments.buffer.patch_type_holes(self).await?;
|
|
|
|
// describe the statement and, again, ask the server to immediately respond
|
|
// we need to fully realize the types
|
|
self.stream.write(message::Describe::Statement(statement));
|
|
self.stream.write(message::Flush);
|
|
self.stream.flush().await?;
|
|
|
|
let _ = recv_desc_params(self).await?;
|
|
let rows = recv_desc_rows(self).await?;
|
|
|
|
self.handle_row_description(rows, true).await?;
|
|
self.wait_until_ready().await?;
|
|
|
|
// bind to attach the arguments to the statement and create a portal
|
|
self.stream.write(Bind {
|
|
portal: None,
|
|
statement,
|
|
formats: &[PgValueFormat::Binary],
|
|
num_params: arguments.types.len() as i16,
|
|
params: &*arguments.buffer,
|
|
result_formats: &[PgValueFormat::Binary],
|
|
});
|
|
|
|
// executes the portal up to the passed limit
|
|
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
|
|
self.stream.write(message::Execute {
|
|
portal: None,
|
|
limit: limit.into(),
|
|
});
|
|
|
|
// finally, [Sync] asks postgres to process the messages that we sent and respond with
|
|
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
|
|
// dozens of queries before a [Sync] and postgres can handle that. Execution on the server
|
|
// is still serial but it would reduce round-trips. Some kind of builder pattern that is
|
|
// termed batching might suit this.
|
|
self.stream.write(message::Sync);
|
|
|
|
// prepared statements are binary
|
|
PgValueFormat::Binary
|
|
} else {
|
|
self.stream.write(Query(query));
|
|
|
|
// and unprepared statements are text
|
|
PgValueFormat::Text
|
|
};
|
|
|
|
// [Query] or [Sync] will trigger a [ReadyForQuery]
|
|
self.pending_ready_for_query_count += 1;
|
|
self.stream.flush().await?;
|
|
|
|
Ok(try_stream! {
|
|
loop {
|
|
let message = self.stream.recv().await?;
|
|
|
|
match message.format {
|
|
MessageFormat::BindComplete
|
|
| MessageFormat::ParseComplete
|
|
| MessageFormat::ParameterDescription
|
|
| MessageFormat::NoData => {
|
|
// harmless messages to ignore
|
|
}
|
|
|
|
MessageFormat::CommandComplete => {
|
|
// a SQL command completed normally
|
|
let cc: CommandComplete = message.decode()?;
|
|
|
|
r#yield!(Either::Left(PgDone {
|
|
rows_affected: cc.rows_affected(),
|
|
}));
|
|
}
|
|
|
|
MessageFormat::EmptyQueryResponse => {
|
|
// empty query string passed to an unprepared execute
|
|
}
|
|
|
|
MessageFormat::RowDescription => {
|
|
// indicates that a *new* set of rows are about to be returned
|
|
self
|
|
.handle_row_description(Some(message.decode()?), false)
|
|
.await?;
|
|
}
|
|
|
|
MessageFormat::DataRow => {
|
|
// one of the set of rows returned by a SELECT, FETCH, etc query
|
|
let data: DataRow = message.decode()?;
|
|
let row = PgRow {
|
|
data,
|
|
format,
|
|
columns: Arc::clone(&self.scratch_row_columns),
|
|
column_names: Arc::clone(&self.scratch_row_column_names),
|
|
};
|
|
|
|
r#yield!(Either::Right(row));
|
|
}
|
|
|
|
MessageFormat::ReadyForQuery => {
|
|
// processing of the query string is complete
|
|
self.handle_ready_for_query(message)?;
|
|
break;
|
|
}
|
|
|
|
_ => {
|
|
Err(err_protocol!(
|
|
"execute: unexpected message: {:?}",
|
|
message.format
|
|
))?;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
})
|
|
}
|
|
}
|
|
|
|
impl<'c> Executor<'c> for &'c mut PgConnection {
|
|
type Database = Postgres;
|
|
|
|
fn fetch_many<'e, 'q: 'e, E: 'q>(
|
|
self,
|
|
mut query: E,
|
|
) -> BoxStream<'e, Result<Either<PgDone, PgRow>, Error>>
|
|
where
|
|
'c: 'e,
|
|
E: Execute<'q, Self::Database>,
|
|
{
|
|
let s = query.query();
|
|
let arguments = query.take_arguments();
|
|
|
|
Box::pin(try_stream! {
|
|
let s = self.run(s, arguments, 0).await?;
|
|
pin_mut!(s);
|
|
|
|
while let Some(v) = s.try_next().await? {
|
|
r#yield!(v);
|
|
}
|
|
|
|
Ok(())
|
|
})
|
|
}
|
|
|
|
fn fetch_optional<'e, 'q: 'e, E: 'q>(
|
|
self,
|
|
mut query: E,
|
|
) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
|
|
where
|
|
'c: 'e,
|
|
E: Execute<'q, Self::Database>,
|
|
{
|
|
let s = query.query();
|
|
let arguments = query.take_arguments();
|
|
|
|
Box::pin(async move {
|
|
let s = self.run(s, arguments, 1).await?;
|
|
pin_mut!(s);
|
|
|
|
while let Some(s) = s.try_next().await? {
|
|
if let Either::Right(r) = s {
|
|
return Ok(Some(r));
|
|
}
|
|
}
|
|
|
|
Ok(None)
|
|
})
|
|
}
|
|
|
|
#[doc(hidden)]
|
|
fn describe<'e, 'q: 'e, E: 'q>(
|
|
self,
|
|
query: E,
|
|
) -> BoxFuture<'e, Result<StatementInfo<Postgres>, Error>>
|
|
where
|
|
'c: 'e,
|
|
E: Execute<'q, Self::Database>,
|
|
{
|
|
let s = query.query();
|
|
|
|
Box::pin(async move {
|
|
let id = prepare(self, s, &Default::default()).await?;
|
|
|
|
self.stream.write(message::Describe::Statement(id));
|
|
self.stream.write(Flush);
|
|
|
|
self.stream.flush().await?;
|
|
|
|
let params = recv_desc_params(self).await?;
|
|
let rows = recv_desc_rows(self).await?;
|
|
|
|
let params = self.handle_parameter_description(params).await?;
|
|
|
|
self.handle_row_description(rows, true).await?;
|
|
|
|
let columns = (&*self.scratch_row_columns).clone();
|
|
let nullable = self.get_nullable_for_columns(&columns).await?;
|
|
|
|
Ok(StatementInfo {
|
|
columns,
|
|
nullable,
|
|
parameters: Some(Either::Left(params)),
|
|
})
|
|
})
|
|
}
|
|
}
|