fix(postgres): guarantee the type name on a PgTypeInfo to always be set

fixes #241
This commit is contained in:
Ryan Leckey 2020-04-10 13:37:08 -07:00
parent cd6735b5d7
commit d360f682f8
7 changed files with 316 additions and 231 deletions

View File

@ -12,6 +12,7 @@ pub struct Describe<DB>
where
DB: Database + ?Sized,
{
// TODO: Describe#param_types should probably be Option<TypeInfo[]> as we either know all the params or we know none
/// The expected types for the parameters of the query.
pub param_types: Box<[Option<DB::TypeInfo>]>,

View File

@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::sync::Arc;
use futures_core::future::BoxFuture;
@ -7,8 +6,8 @@ use crate::connection::ConnectionSource;
use crate::cursor::Cursor;
use crate::executor::Execute;
use crate::pool::Pool;
use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription, StatementId};
use crate::postgres::row::{Column, Statement};
use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription};
use crate::postgres::row::Statement;
use crate::postgres::{PgArguments, PgConnection, PgRow, Postgres};
pub struct PgCursor<'c, 'q> {
@ -53,76 +52,6 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> {
}
}
// Turns a `RowDescription` into a `Statement`: one `Column` per field plus a
// lookup table from column name to column position.
fn parse_row_description(conn: &mut PgConnection, rd: RowDescription) -> Statement {
    let field_count = rd.fields.len();
    let mut names = HashMap::with_capacity(field_count);
    let mut columns = Vec::with_capacity(field_count);

    for (position, field) in rd.fields.iter().enumerate() {
        // Fields without a name get a column entry but no name mapping
        if let Some(name) = field.name.as_ref() {
            names.insert(name.clone(), position);
        }

        columns.push(Column {
            type_info: conn.get_type_info_by_oid(field.type_id.0),
            format: field.type_format,
        });
    }

    Statement {
        names,
        columns: columns.into_boxed_slice(),
    }
}
// Reads messages until the description of the incoming results arrives.
// `ParseComplete` and `BindComplete` are skipped; `NoData` yields an empty
// statement. The resulting column map is later stored in an Arc and shared
// among all rows.
async fn expect_desc(conn: &mut PgConnection) -> crate::Result<Statement> {
    loop {
        match conn.stream.receive().await? {
            Message::ParseComplete | Message::BindComplete => continue,

            Message::RowDescription => {
                let rd = RowDescription::read(conn.stream.buffer())?;
                return Ok(parse_row_description(conn, rd));
            }

            Message::NoData => return Ok(Statement::default()),

            message => {
                return Err(
                    protocol_err!("next/describe: unexpected message: {:?}", message).into(),
                );
            }
        }
    }
}
// Describe backed by the statement cache: a cached description is returned
// as-is, otherwise one is read from the stream and cached under `id`.
async fn get_or_describe(
    conn: &mut PgConnection,
    id: StatementId,
) -> crate::Result<Arc<Statement>> {
    if let Some(cached) = conn.cache_statement.get(&id) {
        return Ok(Arc::clone(cached));
    }

    let described = Arc::new(expect_desc(conn).await?);
    conn.cache_statement.insert(id, Arc::clone(&described));

    Ok(described)
}
async fn next<'a, 'c: 'a, 'q: 'a>(
cursor: &'a mut PgCursor<'c, 'q>,
) -> crate::Result<Option<PgRow<'a>>> {
@ -136,9 +65,8 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
// If there is a statement ID, this is a non-simple or prepared query
if let Some(statement) = statement {
// A prepared statement will re-use the previous column map if
// this query has been executed before
cursor.statement = get_or_describe(&mut *conn, statement).await?;
// A prepared statement will re-use the previous column map
cursor.statement = Arc::clone(&conn.cache_statement[&statement]);
}
// A non-prepared query must be described each time
@ -164,8 +92,12 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
}
Message::RowDescription => {
// NOTE: This is only encountered for unprepared statements
let rd = RowDescription::read(conn.stream.buffer())?;
cursor.statement = Arc::new(parse_row_description(conn, rd));
cursor.statement = Arc::new(
conn.parse_row_description(rd, Default::default(), None, false)
.await?,
);
}
Message::DataRow => {

View File

@ -1,5 +1,6 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_util::{stream, StreamExt, TryStreamExt};
@ -9,9 +10,11 @@ use crate::cursor::Cursor;
use crate::describe::{Column, Describe};
use crate::executor::{Execute, Executor, RefExecutor};
use crate::postgres::protocol::{
self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription,
self, CommandComplete, Message, ParameterDescription, ReadyForQuery, RowDescription,
StatementId, TypeFormat, TypeId,
};
use crate::postgres::row::Column as StatementColumn;
use crate::postgres::row::Statement;
use crate::postgres::type_info::SharedStr;
use crate::postgres::types::try_resolve_type_name;
use crate::postgres::{
@ -56,10 +59,137 @@ impl PgConnection {
query,
});
// [Describe] will return the expected result columns and types
self.write_describe(protocol::Describe::Statement(id));
self.write_sync();
// Flush commands and handle ParseComplete and RowDescription
self.wait_until_ready().await?;
self.stream.flush().await?;
self.is_ready = false;
// wait for `ParseComplete`
match self.stream.receive().await? {
Message::ParseComplete => {}
message => {
return Err(protocol_err!("run: unexpected message: {:?}", message).into());
}
}
// expecting a `ParameterDescription` next
let pd = self.expect_param_desc().await?;
// expecting a `RowDescription` next (or `NoData` for an empty statement)
let statement = self.expect_row_desc(pd).await?;
// cache statement ID and statement description
self.cache_statement_id.insert(query.into(), id);
self.cache_statement.insert(id, Arc::new(statement));
Ok(id)
}
}
// Resolves every parameter type OID in `pd` to a `PgTypeInfo`, asking
// `get_type_info_by_oid` to fetch type info (second argument `true`).
async fn parse_parameter_description(
    &mut self,
    pd: ParameterDescription,
) -> crate::Result<Box<[PgTypeInfo]>> {
    let mut resolved = Vec::with_capacity(pd.ids.len());

    for type_id in pd.ids.iter() {
        resolved.push(self.get_type_info_by_oid(type_id.0, true).await?);
    }

    Ok(resolved.into_boxed_slice())
}
// Builds a `Statement` from a `RowDescription`.
//
// `params` is stored on the statement unchanged. When `type_format` is `Some`
// it overrides the format reported for every field. `fetch_type_info` is
// forwarded to `get_type_info_by_oid` for each field's type OID.
pub(crate) async fn parse_row_description(
    &mut self,
    mut rd: RowDescription,
    params: Box<[PgTypeInfo]>,
    type_format: Option<TypeFormat>,
    fetch_type_info: bool,
) -> crate::Result<Statement> {
    let field_count = rd.fields.len();
    let mut names = HashMap::with_capacity(field_count);
    let mut columns = Vec::with_capacity(field_count);

    for (index, field) in rd.fields.iter_mut().enumerate() {
        // Move the name out of the field so it can be shared between the
        // name -> position map and the column itself
        let name = field.name.take().map(|raw| {
            let shared = SharedStr::from(raw.into_string());
            names.insert(shared.clone(), index);
            shared
        });

        let type_info = self
            .get_type_info_by_oid(field.type_id.0, fetch_type_info)
            .await?;

        let format = type_format.unwrap_or(field.type_format);

        columns.push(StatementColumn {
            type_info,
            name,
            format,
            table_id: field.table_id,
            column_id: field.column_id,
        });
    }

    Ok(Statement {
        params,
        columns: columns.into_boxed_slice(),
        names,
    })
}
// Receives the next message and requires it to be a `ParameterDescription`;
// any other message is reported as a protocol error.
async fn expect_param_desc(&mut self) -> crate::Result<ParameterDescription> {
    match self.stream.receive().await? {
        Message::ParameterDescription => {
            Ok(ParameterDescription::read(self.stream.buffer())?)
        }

        message => {
            Err(protocol_err!("next/describe: unexpected message: {:?}", message).into())
        }
    }
}
// Receives the description of the incoming results (`RowDescription`, or
// `NoData` when there are none) and combines it with the already received
// parameter description into a `Statement`.
// The statement is stored in an Arc by the caller and shared among all rows.
async fn expect_row_desc(&mut self, pd: ParameterDescription) -> crate::Result<Statement> {
    let row_description = match self.stream.receive().await? {
        Message::NoData => None,
        Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
        message => {
            return Err(
                protocol_err!("next/describe: unexpected message: {:?}", message).into(),
            );
        }
    };

    // Parameters are resolved only after the row description has been read
    // off the stream
    let params = self.parse_parameter_description(pd).await?;

    match row_description {
        Some(rd) => {
            self.parse_row_description(rd, params, Some(TypeFormat::Binary), true)
                .await
        }
        None => Ok(Statement {
            params,
            names: HashMap::new(),
            columns: Default::default(),
        }),
    }
}
// Passes a [Describe] message (statement or portal) to `self.stream.write`.
// Nothing is received here; callers follow up with their own flush/receive.
pub(crate) fn write_describe(&mut self, d: protocol::Describe) {
self.stream.write(d);
}
@ -132,12 +262,6 @@ impl PgConnection {
// Next, [Bind] attaches the arguments to the statement and creates a named portal
self.write_bind("", statement, &mut arguments).await?;
// Next, [Describe] will return the expected result columns and types
// Conditionally run [Describe] only if the results have not been cached
if !self.cache_statement.contains_key(&statement) {
self.write_describe(protocol::Describe::Portal(""));
}
// Next, [Execute] then executes the named portal
self.write_execute("", 0);
@ -161,24 +285,6 @@ impl PgConnection {
self.stream.flush().await?;
self.is_ready = false;
// only cache
if let Some(statement) = statement {
// prefer redundant lookup to copying the query string
if !self.cache_statement_id.contains_key(query) {
// wait for `ParseComplete` on the stream or the
// error before we cache the statement
match self.stream.receive().await? {
Message::ParseComplete => {
self.cache_statement_id.insert(query.into(), statement);
}
message => {
return Err(protocol_err!("run: unexpected message: {:?}", message).into());
}
}
}
}
Ok(statement)
}
@ -186,71 +292,18 @@ impl PgConnection {
&'e mut self,
query: &'q str,
) -> crate::Result<Describe<Postgres>> {
self.is_ready = false;
let statement = self.write_prepare(query, &Default::default()).await?;
self.write_describe(protocol::Describe::Statement(statement));
self.write_sync();
self.stream.flush().await?;
let params = loop {
match self.stream.receive().await? {
Message::ParseComplete => {}
Message::ParameterDescription => {
break ParameterDescription::read(self.stream.buffer())?;
}
message => {
return Err(protocol_err!(
"expected ParameterDescription; received {:?}",
message
)
.into());
}
};
};
let result = match self.stream.receive().await? {
Message::NoData => None,
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
message => {
return Err(protocol_err!(
"expected RowDescription or NoData; received {:?}",
message
)
.into());
}
};
self.wait_until_ready().await?;
let result_fields = result.map_or_else(Default::default, |r| r.fields);
let type_names = self
.get_type_names(
params
.ids
.iter()
.cloned()
.chain(result_fields.iter().map(|field| field.type_id)),
)
.await?;
let statement_id = self.write_prepare(query, &Default::default()).await?;
let statement = &self.cache_statement[&statement_id];
let columns = statement.columns.to_vec();
Ok(Describe {
param_types: params
.ids
param_types: statement
.params
.iter()
.map(|id| Some(PgTypeInfo::new(*id, &type_names[&id.0])))
.map(|info| Some(info.clone()))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: self
.map_result_columns(result_fields, type_names)
.await?
.into_boxed_slice(),
result_columns: self.map_result_columns(columns).await?.into_boxed_slice(),
})
}
@ -277,71 +330,50 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
Ok(oid)
}
pub(crate) fn get_type_info_by_oid(&mut self, oid: u32) -> PgTypeInfo {
pub(crate) async fn get_type_info_by_oid(
&mut self,
oid: u32,
fetch_type_info: bool,
) -> crate::Result<PgTypeInfo> {
if let Some(name) = try_resolve_type_name(oid) {
return PgTypeInfo::new(TypeId(oid), name);
return Ok(PgTypeInfo::new(TypeId(oid), name));
}
if let Some(name) = self.cache_type_name.get(&oid) {
return PgTypeInfo::new(TypeId(oid), name);
return Ok(PgTypeInfo::new(TypeId(oid), name));
}
// NOTE: The name isn't too important for the decode lifecycle
return PgTypeInfo::new(TypeId(oid), "");
}
let name = if fetch_type_info {
// language=SQL
let (name,): (String,) = query_as(
"
SELECT UPPER(typname) FROM pg_catalog.pg_type WHERE oid = $1
",
)
.bind(oid)
.fetch_one(&mut *self)
.await?;
async fn get_type_names(
&mut self,
ids: impl IntoIterator<Item = TypeId>,
) -> crate::Result<HashMap<u32, SharedStr>> {
let type_ids: HashSet<u32> = ids.into_iter().map(|id| id.0).collect::<HashSet<u32>>();
// Emplace the new type name <-> OID association in the cache
let shared = SharedStr::from(name);
if type_ids.is_empty() {
return Ok(HashMap::new());
}
self.cache_type_oid.insert(shared.clone(), oid);
self.cache_type_name.insert(oid, shared.clone());
// uppercase type names are easier to visually identify
let mut query = "select types.type_id, UPPER(pg_type.typname) from (VALUES ".to_string();
let mut args = PgArguments::default();
let mut pushed = false;
shared
} else {
// NOTE: The name isn't too important for the decode lifecycle of TEXT
SharedStr::Static("")
};
// TODO: dedup this with the one below, ideally as an API we can export
for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() {
if pushed {
query += ", ";
}
pushed = true;
let _ = write!(query, "(${}, ${})", bind, bind + 1);
// not used in the output but ensures our values are sorted correctly
args.add(i as i32);
args.add(type_id as i32);
}
query += ") as types(idx, type_id) \
inner join pg_catalog.pg_type on pg_type.oid = type_id \
order by types.idx";
crate::query::query(&query)
.bind_all(args)
.try_map(|row: PgRow| -> crate::Result<(u32, SharedStr)> {
Ok((
row.try_get::<i32, _>(0)? as u32,
row.try_get::<String, _>(1)?.into(),
))
})
.fetch(self)
.try_collect()
.await
Ok(PgTypeInfo::new(TypeId(oid), name))
}
async fn map_result_columns(
&mut self,
fields: Box<[Field]>,
type_names: HashMap<u32, SharedStr>,
columns: Vec<StatementColumn>,
) -> crate::Result<Vec<Column<Postgres>>> {
if fields.is_empty() {
if columns.is_empty() {
return Ok(vec![]);
}
@ -349,7 +381,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
let mut pushed = false;
let mut args = PgArguments::default();
for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() {
for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() {
if pushed {
query += ", ";
}
@ -364,8 +396,8 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
);
args.add(i as i32);
args.add(field.table_id.map(|id| id as i32));
args.add(field.column_id);
args.add(column.table_id.map(|id| id as i32));
args.add(column.column_id);
}
query += ") as col(idx, table_id, col_idx) \
@ -383,23 +415,20 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
Ok((idx, non_null))
})
.fetch(self)
.zip(stream::iter(fields.into_vec().into_iter().enumerate()))
.map(|(row, (fidx, field))| -> crate::Result<Column<_>> {
.zip(stream::iter(columns.into_iter().enumerate()))
.map(|(row, (fidx, column))| -> crate::Result<Column<_>> {
let (idx, non_null) = row?;
if idx != fidx as i32 {
return Err(
protocol_err!("missing field from query, field: {:?}", field).into(),
protocol_err!("missing field from query, field: {:?}", column).into(),
);
}
Ok(Column {
name: field.name,
table_id: field.table_id,
type_info: Some(PgTypeInfo::new(
field.type_id,
&type_names[&field.type_id.0],
)),
name: column.name.map(|name| (&*name).into()),
table_id: column.table_id,
type_info: Some(column.type_info),
non_null,
})
})

View File

@ -2,6 +2,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use crate::postgres::protocol::{DataRow, TypeFormat};
use crate::postgres::type_info::SharedStr;
use crate::postgres::value::PgValue;
use crate::postgres::{PgTypeInfo, Postgres};
use crate::row::{ColumnIndex, Row};
@ -10,17 +11,24 @@ use crate::row::{ColumnIndex, Row};
// For Postgres, each column has an OID and a format (binary or text)
// For simple (unprepared) queries, format will always be text
// For prepared queries, format will _almost_ always be binary
// Description of one result column, as built by
// `PgConnection::parse_row_description` from a `RowDescription` field.
#[derive(Clone, Debug)]
pub(crate) struct Column {
// Column name, when the field carried one
pub(crate) name: Option<SharedStr>,
// Type resolved from the field's type OID
pub(crate) type_info: PgTypeInfo,
// Format (binary or text) the column's values are expected in
pub(crate) format: TypeFormat,
// Copied from the field's `table_id`; later bound as a query argument by
// `map_result_columns`
pub(crate) table_id: Option<u32>,
// Copied from the field's `column_id`; bound alongside `table_id`
pub(crate) column_id: i16,
}
// A statement description containing the column information used to
// properly decode data
#[derive(Default)]
pub(crate) struct Statement {
// parameters
pub(crate) params: Box<[PgTypeInfo]>,
// column name -> position
pub(crate) names: HashMap<Box<str>, usize>,
pub(crate) names: HashMap<SharedStr, usize>,
// all columns
pub(crate) columns: Box<[Column]>,

View File

@ -3,6 +3,7 @@ use crate::types::TypeInfo;
use std::borrow::Borrow;
use std::fmt;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
@ -135,7 +136,7 @@ impl TypeInfo for PgTypeInfo {
}
/// Copy of `Cow` but for strings; clones guaranteed to be cheap.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum SharedStr {
Static(&'static str),
Arc(Arc<str>),
@ -152,6 +153,14 @@ impl Deref for SharedStr {
}
}
impl Hash for SharedStr {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Hash only the underlying text. A derived `Hash` would also feed the
        // enum discriminant into the hasher, so the result would depend on
        // which variant holds the string.
        let text: &str = &**self;
        text.hash(state);
    }
}
impl Borrow<str> for SharedStr {
fn borrow(&self) -> &str {
&**self

View File

@ -72,9 +72,11 @@ impl<'c> PgValue<'c> {
impl<'c> RawValue<'c> for PgValue<'c> {
type Database = Postgres;
// The public type_info is used for type compatibility checks
fn type_info(&self) -> Option<PgTypeInfo> {
if let (Some(type_info), Some(_)) = (&self.type_info, &self.data) {
Some(type_info.clone())
// For TEXT encoding the type defined on the value is unreliable
if matches!(self.data, Some(PgData::Binary(_))) {
self.type_info.clone()
} else {
None
}

View File

@ -1,4 +1,4 @@
use sqlx::{postgres::PgQueryAs, Executor, Postgres};
use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres};
use sqlx_test::{new, test_type};
use std::fmt::Debug;
@ -16,7 +16,7 @@ enum Weak {
Three = 4,
}
// "Strong" enums can map to TEXT (25) or a custom enum type
// "Strong" enums can map to TEXT (25)
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(rename = "text")]
#[sqlx(rename_all = "lowercase")]
@ -28,6 +28,16 @@ enum Strong {
Three,
}
// "Strong" enum can map to a custom type
// Renamed to `mood`, the enum type that `test_enum_type` creates with the
// labels 'ok', 'happy' and 'sad'; `rename_all = "lowercase"` is what lines the
// variant names up with those labels.
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(rename = "mood")]
#[sqlx(rename_all = "lowercase")]
enum Mood {
Ok,
Happy,
Sad,
}
// Records must map to a custom type
// Note that all types are types in Postgres
#[derive(PartialEq, Debug, sqlx::Type)]
@ -61,6 +71,100 @@ test_type!(strong_enum(
"'four'::text" == Strong::Three
));
// Round-trips the custom `mood` enum through the database: as a bound
// parameter on INSERT, as a decoded column on SELECT (with and without bind
// parameters), and as both parameter and result in a plain expression.
// The connection is closed and re-opened between steps so each step starts
// on a fresh connection.
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_enum_type() -> anyhow::Result<()> {
    let mut conn = new::<Postgres>().await?;

    // Create the enum type (tolerating "already exists") and an empty table
    // that uses it
    conn.execute(
        r#"
DO $$ BEGIN
CREATE TYPE mood AS ENUM ( 'ok', 'happy', 'sad' );
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
CREATE TABLE IF NOT EXISTS people (
id serial PRIMARY KEY,
mood mood not null
);
TRUNCATE people;
"#,
    )
    .await?;

    // Drop and re-acquire the connection
    conn.close().await?;
    let mut conn = new::<Postgres>().await?;

    // Insert a row, binding the custom enum as a parameter
    let (people_id,): (i32,) = sqlx::query_as(
        "
INSERT INTO people (mood)
VALUES ($1)
RETURNING id
",
    )
    .bind(Mood::Sad)
    .fetch_one(&mut conn)
    .await?;

    // Drop and re-acquire the connection
    conn.close().await?;
    let mut conn = new::<Postgres>().await?;

    #[derive(sqlx::FromRow)]
    struct PeopleRow {
        id: i32,
        mood: Mood,
    }

    // Read the row back with a bound id and decode the enum column
    let rec: PeopleRow = sqlx::query_as(
        "
SELECT id, mood FROM people WHERE id = $1
",
    )
    .bind(people_id)
    .fetch_one(&mut conn)
    .await?;

    assert_eq!(rec.id, people_id);
    assert_eq!(rec.mood, Mood::Sad);

    // Drop and re-acquire the connection
    conn.close().await?;
    let mut conn = new::<Postgres>().await?;

    // Same read with the id inlined into the SQL text and no bind parameters
    // (presumably the unprepared query path)
    let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id);

    let mut cursor = conn.fetch(&*stmt);
    let row = cursor.next().await?.unwrap();
    let rec = PeopleRow::from_row(&row)?;

    assert_eq!(rec.id, people_id);
    assert_eq!(rec.mood, Mood::Sad);

    // Normal type equivalency test
    let rec: (bool, Mood) = sqlx::query_as(
        "
SELECT $1 = 'happy'::mood, $1
",
    )
    .bind(&Mood::Happy)
    .fetch_one(&mut conn)
    .await?;

    assert!(rec.0);
    assert_eq!(rec.1, Mood::Happy);

    Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_record_type() -> anyhow::Result<()> {