mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-10-02 07:21:08 +00:00
fix(postgres): guarantee the type name on a PgTypeInfo to always be set
fixes #241
This commit is contained in:
parent
cd6735b5d7
commit
d360f682f8
@ -12,6 +12,7 @@ pub struct Describe<DB>
|
||||
where
|
||||
DB: Database + ?Sized,
|
||||
{
|
||||
// TODO: Describe#param_types should probably be Option<TypeInfo[]> as we either know all the params or we know none
|
||||
/// The expected types for the parameters of the query.
|
||||
pub param_types: Box<[Option<DB::TypeInfo>]>,
|
||||
|
||||
|
@ -1,4 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
@ -7,8 +6,8 @@ use crate::connection::ConnectionSource;
|
||||
use crate::cursor::Cursor;
|
||||
use crate::executor::Execute;
|
||||
use crate::pool::Pool;
|
||||
use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription, StatementId};
|
||||
use crate::postgres::row::{Column, Statement};
|
||||
use crate::postgres::protocol::{DataRow, Message, ReadyForQuery, RowDescription};
|
||||
use crate::postgres::row::Statement;
|
||||
use crate::postgres::{PgArguments, PgConnection, PgRow, Postgres};
|
||||
|
||||
pub struct PgCursor<'c, 'q> {
|
||||
@ -53,76 +52,6 @@ impl<'c, 'q> Cursor<'c, 'q> for PgCursor<'c, 'q> {
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_row_description(conn: &mut PgConnection, rd: RowDescription) -> Statement {
|
||||
let mut names = HashMap::new();
|
||||
let mut columns = Vec::new();
|
||||
|
||||
columns.reserve(rd.fields.len());
|
||||
names.reserve(rd.fields.len());
|
||||
|
||||
for (index, field) in rd.fields.iter().enumerate() {
|
||||
if let Some(name) = &field.name {
|
||||
names.insert(name.clone(), index);
|
||||
}
|
||||
|
||||
let type_info = conn.get_type_info_by_oid(field.type_id.0);
|
||||
|
||||
columns.push(Column {
|
||||
type_info,
|
||||
format: field.type_format,
|
||||
});
|
||||
}
|
||||
|
||||
Statement {
|
||||
columns: columns.into_boxed_slice(),
|
||||
names,
|
||||
}
|
||||
}
|
||||
|
||||
// Used to describe the incoming results
|
||||
// We store the column map in an Arc and share it among all rows
|
||||
async fn expect_desc(conn: &mut PgConnection) -> crate::Result<Statement> {
|
||||
let description: Option<_> = loop {
|
||||
match conn.stream.receive().await? {
|
||||
Message::ParseComplete | Message::BindComplete => {}
|
||||
|
||||
Message::RowDescription => {
|
||||
break Some(RowDescription::read(conn.stream.buffer())?);
|
||||
}
|
||||
|
||||
Message::NoData => {
|
||||
break None;
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(
|
||||
protocol_err!("next/describe: unexpected message: {:?}", message).into(),
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(description) = description {
|
||||
Ok(parse_row_description(conn, description))
|
||||
} else {
|
||||
Ok(Statement::default())
|
||||
}
|
||||
}
|
||||
|
||||
// A form of describe that uses the statement cache
|
||||
async fn get_or_describe(
|
||||
conn: &mut PgConnection,
|
||||
id: StatementId,
|
||||
) -> crate::Result<Arc<Statement>> {
|
||||
if !conn.cache_statement.contains_key(&id) {
|
||||
let statement = expect_desc(conn).await?;
|
||||
|
||||
conn.cache_statement.insert(id, Arc::new(statement));
|
||||
}
|
||||
|
||||
Ok(Arc::clone(&conn.cache_statement[&id]))
|
||||
}
|
||||
|
||||
async fn next<'a, 'c: 'a, 'q: 'a>(
|
||||
cursor: &'a mut PgCursor<'c, 'q>,
|
||||
) -> crate::Result<Option<PgRow<'a>>> {
|
||||
@ -136,9 +65,8 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
|
||||
|
||||
// If there is a statement ID, this is a non-simple or prepared query
|
||||
if let Some(statement) = statement {
|
||||
// A prepared statement will re-use the previous column map if
|
||||
// this query has been executed before
|
||||
cursor.statement = get_or_describe(&mut *conn, statement).await?;
|
||||
// A prepared statement will re-use the previous column map
|
||||
cursor.statement = Arc::clone(&conn.cache_statement[&statement]);
|
||||
}
|
||||
|
||||
// A non-prepared query must be described each time
|
||||
@ -164,8 +92,12 @@ async fn next<'a, 'c: 'a, 'q: 'a>(
|
||||
}
|
||||
|
||||
Message::RowDescription => {
|
||||
// NOTE: This is only encountered for unprepared statements
|
||||
let rd = RowDescription::read(conn.stream.buffer())?;
|
||||
cursor.statement = Arc::new(parse_row_description(conn, rd));
|
||||
cursor.statement = Arc::new(
|
||||
conn.parse_row_description(rd, Default::default(), None, false)
|
||||
.await?,
|
||||
);
|
||||
}
|
||||
|
||||
Message::DataRow => {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_util::{stream, StreamExt, TryStreamExt};
|
||||
@ -9,9 +10,11 @@ use crate::cursor::Cursor;
|
||||
use crate::describe::{Column, Describe};
|
||||
use crate::executor::{Execute, Executor, RefExecutor};
|
||||
use crate::postgres::protocol::{
|
||||
self, CommandComplete, Field, Message, ParameterDescription, ReadyForQuery, RowDescription,
|
||||
self, CommandComplete, Message, ParameterDescription, ReadyForQuery, RowDescription,
|
||||
StatementId, TypeFormat, TypeId,
|
||||
};
|
||||
use crate::postgres::row::Column as StatementColumn;
|
||||
use crate::postgres::row::Statement;
|
||||
use crate::postgres::type_info::SharedStr;
|
||||
use crate::postgres::types::try_resolve_type_name;
|
||||
use crate::postgres::{
|
||||
@ -56,10 +59,137 @@ impl PgConnection {
|
||||
query,
|
||||
});
|
||||
|
||||
// [Describe] will return the expected result columns and types
|
||||
self.write_describe(protocol::Describe::Statement(id));
|
||||
self.write_sync();
|
||||
|
||||
// Flush commands and handle ParseComplete and RowDescription
|
||||
self.wait_until_ready().await?;
|
||||
self.stream.flush().await?;
|
||||
self.is_ready = false;
|
||||
|
||||
// wait for `ParseComplete`
|
||||
match self.stream.receive().await? {
|
||||
Message::ParseComplete => {}
|
||||
message => {
|
||||
return Err(protocol_err!("run: unexpected message: {:?}", message).into());
|
||||
}
|
||||
}
|
||||
|
||||
// expecting a `ParameterDescription` next
|
||||
let pd = self.expect_param_desc().await?;
|
||||
|
||||
// expecting a `RowDescription` next (or `NoData` for an empty statement)
|
||||
let statement = self.expect_row_desc(pd).await?;
|
||||
|
||||
// cache statement ID and statement description
|
||||
self.cache_statement_id.insert(query.into(), id);
|
||||
self.cache_statement.insert(id, Arc::new(statement));
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
}
|
||||
|
||||
async fn parse_parameter_description(
|
||||
&mut self,
|
||||
pd: ParameterDescription,
|
||||
) -> crate::Result<Box<[PgTypeInfo]>> {
|
||||
let mut params = Vec::with_capacity(pd.ids.len());
|
||||
|
||||
for ty in pd.ids.iter() {
|
||||
let type_info = self.get_type_info_by_oid(ty.0, true).await?;
|
||||
|
||||
params.push(type_info);
|
||||
}
|
||||
|
||||
Ok(params.into_boxed_slice())
|
||||
}
|
||||
|
||||
pub(crate) async fn parse_row_description(
|
||||
&mut self,
|
||||
mut rd: RowDescription,
|
||||
params: Box<[PgTypeInfo]>,
|
||||
type_format: Option<TypeFormat>,
|
||||
fetch_type_info: bool,
|
||||
) -> crate::Result<Statement> {
|
||||
let mut names = HashMap::new();
|
||||
let mut columns = Vec::new();
|
||||
|
||||
columns.reserve(rd.fields.len());
|
||||
names.reserve(rd.fields.len());
|
||||
|
||||
for (index, field) in rd.fields.iter_mut().enumerate() {
|
||||
let name = if let Some(name) = field.name.take() {
|
||||
let name = SharedStr::from(name.into_string());
|
||||
names.insert(name.clone(), index);
|
||||
Some(name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let type_info = self
|
||||
.get_type_info_by_oid(field.type_id.0, fetch_type_info)
|
||||
.await?;
|
||||
|
||||
columns.push(StatementColumn {
|
||||
type_info,
|
||||
name,
|
||||
format: type_format.unwrap_or(field.type_format),
|
||||
table_id: field.table_id,
|
||||
column_id: field.column_id,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Statement {
|
||||
params,
|
||||
columns: columns.into_boxed_slice(),
|
||||
names,
|
||||
})
|
||||
}
|
||||
|
||||
async fn expect_param_desc(&mut self) -> crate::Result<ParameterDescription> {
|
||||
let description = match self.stream.receive().await? {
|
||||
Message::ParameterDescription => ParameterDescription::read(self.stream.buffer())?,
|
||||
|
||||
message => {
|
||||
return Err(
|
||||
protocol_err!("next/describe: unexpected message: {:?}", message).into(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(description)
|
||||
}
|
||||
|
||||
// Used to describe the incoming results
|
||||
// We store the column map in an Arc and share it among all rows
|
||||
async fn expect_row_desc(&mut self, pd: ParameterDescription) -> crate::Result<Statement> {
|
||||
let description: Option<_> = match self.stream.receive().await? {
|
||||
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
|
||||
|
||||
Message::NoData => None,
|
||||
|
||||
message => {
|
||||
return Err(
|
||||
protocol_err!("next/describe: unexpected message: {:?}", message).into(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let params = self.parse_parameter_description(pd).await?;
|
||||
|
||||
if let Some(description) = description {
|
||||
self.parse_row_description(description, params, Some(TypeFormat::Binary), true)
|
||||
.await
|
||||
} else {
|
||||
Ok(Statement {
|
||||
params,
|
||||
names: HashMap::new(),
|
||||
columns: Default::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn write_describe(&mut self, d: protocol::Describe) {
|
||||
self.stream.write(d);
|
||||
}
|
||||
@ -132,12 +262,6 @@ impl PgConnection {
|
||||
// Next, [Bind] attaches the arguments to the statement and creates a named portal
|
||||
self.write_bind("", statement, &mut arguments).await?;
|
||||
|
||||
// Next, [Describe] will return the expected result columns and types
|
||||
// Conditionally run [Describe] only if the results have not been cached
|
||||
if !self.cache_statement.contains_key(&statement) {
|
||||
self.write_describe(protocol::Describe::Portal(""));
|
||||
}
|
||||
|
||||
// Next, [Execute] then executes the named portal
|
||||
self.write_execute("", 0);
|
||||
|
||||
@ -161,24 +285,6 @@ impl PgConnection {
|
||||
self.stream.flush().await?;
|
||||
self.is_ready = false;
|
||||
|
||||
// only cache
|
||||
if let Some(statement) = statement {
|
||||
// prefer redundant lookup to copying the query string
|
||||
if !self.cache_statement_id.contains_key(query) {
|
||||
// wait for `ParseComplete` on the stream or the
|
||||
// error before we cache the statement
|
||||
match self.stream.receive().await? {
|
||||
Message::ParseComplete => {
|
||||
self.cache_statement_id.insert(query.into(), statement);
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(protocol_err!("run: unexpected message: {:?}", message).into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statement)
|
||||
}
|
||||
|
||||
@ -186,71 +292,18 @@ impl PgConnection {
|
||||
&'e mut self,
|
||||
query: &'q str,
|
||||
) -> crate::Result<Describe<Postgres>> {
|
||||
self.is_ready = false;
|
||||
|
||||
let statement = self.write_prepare(query, &Default::default()).await?;
|
||||
|
||||
self.write_describe(protocol::Describe::Statement(statement));
|
||||
self.write_sync();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
let params = loop {
|
||||
match self.stream.receive().await? {
|
||||
Message::ParseComplete => {}
|
||||
|
||||
Message::ParameterDescription => {
|
||||
break ParameterDescription::read(self.stream.buffer())?;
|
||||
}
|
||||
|
||||
message => {
|
||||
return Err(protocol_err!(
|
||||
"expected ParameterDescription; received {:?}",
|
||||
message
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let result = match self.stream.receive().await? {
|
||||
Message::NoData => None,
|
||||
Message::RowDescription => Some(RowDescription::read(self.stream.buffer())?),
|
||||
|
||||
message => {
|
||||
return Err(protocol_err!(
|
||||
"expected RowDescription or NoData; received {:?}",
|
||||
message
|
||||
)
|
||||
.into());
|
||||
}
|
||||
};
|
||||
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
let result_fields = result.map_or_else(Default::default, |r| r.fields);
|
||||
|
||||
let type_names = self
|
||||
.get_type_names(
|
||||
params
|
||||
.ids
|
||||
.iter()
|
||||
.cloned()
|
||||
.chain(result_fields.iter().map(|field| field.type_id)),
|
||||
)
|
||||
.await?;
|
||||
let statement_id = self.write_prepare(query, &Default::default()).await?;
|
||||
let statement = &self.cache_statement[&statement_id];
|
||||
let columns = statement.columns.to_vec();
|
||||
|
||||
Ok(Describe {
|
||||
param_types: params
|
||||
.ids
|
||||
param_types: statement
|
||||
.params
|
||||
.iter()
|
||||
.map(|id| Some(PgTypeInfo::new(*id, &type_names[&id.0])))
|
||||
.map(|info| Some(info.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(),
|
||||
result_columns: self
|
||||
.map_result_columns(result_fields, type_names)
|
||||
.await?
|
||||
.into_boxed_slice(),
|
||||
result_columns: self.map_result_columns(columns).await?.into_boxed_slice(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -277,71 +330,50 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
|
||||
Ok(oid)
|
||||
}
|
||||
|
||||
pub(crate) fn get_type_info_by_oid(&mut self, oid: u32) -> PgTypeInfo {
|
||||
pub(crate) async fn get_type_info_by_oid(
|
||||
&mut self,
|
||||
oid: u32,
|
||||
fetch_type_info: bool,
|
||||
) -> crate::Result<PgTypeInfo> {
|
||||
if let Some(name) = try_resolve_type_name(oid) {
|
||||
return PgTypeInfo::new(TypeId(oid), name);
|
||||
return Ok(PgTypeInfo::new(TypeId(oid), name));
|
||||
}
|
||||
|
||||
if let Some(name) = self.cache_type_name.get(&oid) {
|
||||
return PgTypeInfo::new(TypeId(oid), name);
|
||||
return Ok(PgTypeInfo::new(TypeId(oid), name));
|
||||
}
|
||||
|
||||
// NOTE: The name isn't too important for the decode lifecycle
|
||||
return PgTypeInfo::new(TypeId(oid), "");
|
||||
}
|
||||
let name = if fetch_type_info {
|
||||
// language=SQL
|
||||
let (name,): (String,) = query_as(
|
||||
"
|
||||
SELECT UPPER(typname) FROM pg_catalog.pg_type WHERE oid = $1
|
||||
",
|
||||
)
|
||||
.bind(oid)
|
||||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
async fn get_type_names(
|
||||
&mut self,
|
||||
ids: impl IntoIterator<Item = TypeId>,
|
||||
) -> crate::Result<HashMap<u32, SharedStr>> {
|
||||
let type_ids: HashSet<u32> = ids.into_iter().map(|id| id.0).collect::<HashSet<u32>>();
|
||||
// Emplace the new type name <-> OID association in the cache
|
||||
let shared = SharedStr::from(name);
|
||||
|
||||
if type_ids.is_empty() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
self.cache_type_oid.insert(shared.clone(), oid);
|
||||
self.cache_type_name.insert(oid, shared.clone());
|
||||
|
||||
// uppercase type names are easier to visually identify
|
||||
let mut query = "select types.type_id, UPPER(pg_type.typname) from (VALUES ".to_string();
|
||||
let mut args = PgArguments::default();
|
||||
let mut pushed = false;
|
||||
shared
|
||||
} else {
|
||||
// NOTE: The name isn't too important for the decode lifecycle of TEXT
|
||||
SharedStr::Static("")
|
||||
};
|
||||
|
||||
// TODO: dedup this with the one below, ideally as an API we can export
|
||||
for (i, (&type_id, bind)) in type_ids.iter().zip((1..).step_by(2)).enumerate() {
|
||||
if pushed {
|
||||
query += ", ";
|
||||
}
|
||||
|
||||
pushed = true;
|
||||
let _ = write!(query, "(${}, ${})", bind, bind + 1);
|
||||
|
||||
// not used in the output but ensures are values are sorted correctly
|
||||
args.add(i as i32);
|
||||
args.add(type_id as i32);
|
||||
}
|
||||
|
||||
query += ") as types(idx, type_id) \
|
||||
inner join pg_catalog.pg_type on pg_type.oid = type_id \
|
||||
order by types.idx";
|
||||
|
||||
crate::query::query(&query)
|
||||
.bind_all(args)
|
||||
.try_map(|row: PgRow| -> crate::Result<(u32, SharedStr)> {
|
||||
Ok((
|
||||
row.try_get::<i32, _>(0)? as u32,
|
||||
row.try_get::<String, _>(1)?.into(),
|
||||
))
|
||||
})
|
||||
.fetch(self)
|
||||
.try_collect()
|
||||
.await
|
||||
Ok(PgTypeInfo::new(TypeId(oid), name))
|
||||
}
|
||||
|
||||
async fn map_result_columns(
|
||||
&mut self,
|
||||
fields: Box<[Field]>,
|
||||
type_names: HashMap<u32, SharedStr>,
|
||||
columns: Vec<StatementColumn>,
|
||||
) -> crate::Result<Vec<Column<Postgres>>> {
|
||||
if fields.is_empty() {
|
||||
if columns.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
@ -349,7 +381,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
|
||||
let mut pushed = false;
|
||||
let mut args = PgArguments::default();
|
||||
|
||||
for (i, (field, bind)) in fields.iter().zip((1..).step_by(3)).enumerate() {
|
||||
for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() {
|
||||
if pushed {
|
||||
query += ", ";
|
||||
}
|
||||
@ -364,8 +396,8 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
|
||||
);
|
||||
|
||||
args.add(i as i32);
|
||||
args.add(field.table_id.map(|id| id as i32));
|
||||
args.add(field.column_id);
|
||||
args.add(column.table_id.map(|id| id as i32));
|
||||
args.add(column.column_id);
|
||||
}
|
||||
|
||||
query += ") as col(idx, table_id, col_idx) \
|
||||
@ -383,23 +415,20 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
|
||||
Ok((idx, non_null))
|
||||
})
|
||||
.fetch(self)
|
||||
.zip(stream::iter(fields.into_vec().into_iter().enumerate()))
|
||||
.map(|(row, (fidx, field))| -> crate::Result<Column<_>> {
|
||||
.zip(stream::iter(columns.into_iter().enumerate()))
|
||||
.map(|(row, (fidx, column))| -> crate::Result<Column<_>> {
|
||||
let (idx, non_null) = row?;
|
||||
|
||||
if idx != fidx as i32 {
|
||||
return Err(
|
||||
protocol_err!("missing field from query, field: {:?}", field).into(),
|
||||
protocol_err!("missing field from query, field: {:?}", column).into(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Column {
|
||||
name: field.name,
|
||||
table_id: field.table_id,
|
||||
type_info: Some(PgTypeInfo::new(
|
||||
field.type_id,
|
||||
&type_names[&field.type_id.0],
|
||||
)),
|
||||
name: column.name.map(|name| (&*name).into()),
|
||||
table_id: column.table_id,
|
||||
type_info: Some(column.type_info),
|
||||
non_null,
|
||||
})
|
||||
})
|
||||
|
@ -2,6 +2,7 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::postgres::protocol::{DataRow, TypeFormat};
|
||||
use crate::postgres::type_info::SharedStr;
|
||||
use crate::postgres::value::PgValue;
|
||||
use crate::postgres::{PgTypeInfo, Postgres};
|
||||
use crate::row::{ColumnIndex, Row};
|
||||
@ -10,17 +11,24 @@ use crate::row::{ColumnIndex, Row};
|
||||
// For Postgres, each column has an OID and a format (binary or text)
|
||||
// For simple (unprepared) queries, format will always be text
|
||||
// For prepared queries, format will _almost_ always be binary
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct Column {
|
||||
pub(crate) name: Option<SharedStr>,
|
||||
pub(crate) type_info: PgTypeInfo,
|
||||
pub(crate) format: TypeFormat,
|
||||
pub(crate) table_id: Option<u32>,
|
||||
pub(crate) column_id: i16,
|
||||
}
|
||||
|
||||
// A statement description containing the column information used to
|
||||
// properly decode data
|
||||
#[derive(Default)]
|
||||
pub(crate) struct Statement {
|
||||
// paramaters
|
||||
pub(crate) params: Box<[PgTypeInfo]>,
|
||||
|
||||
// column name -> position
|
||||
pub(crate) names: HashMap<Box<str>, usize>,
|
||||
pub(crate) names: HashMap<SharedStr, usize>,
|
||||
|
||||
// all columns
|
||||
pub(crate) columns: Box<[Column]>,
|
||||
|
@ -3,6 +3,7 @@ use crate::types::TypeInfo;
|
||||
use std::borrow::Borrow;
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -135,7 +136,7 @@ impl TypeInfo for PgTypeInfo {
|
||||
}
|
||||
|
||||
/// Copy of `Cow` but for strings; clones guaranteed to be cheap.
|
||||
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum SharedStr {
|
||||
Static(&'static str),
|
||||
Arc(Arc<str>),
|
||||
@ -152,6 +153,14 @@ impl Deref for SharedStr {
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for SharedStr {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
// Forward the hash to the string representation of this
|
||||
// A derive(Hash) encodes the enum discriminant
|
||||
(&**self).hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl Borrow<str> for SharedStr {
|
||||
fn borrow(&self) -> &str {
|
||||
&**self
|
||||
|
@ -72,9 +72,11 @@ impl<'c> PgValue<'c> {
|
||||
impl<'c> RawValue<'c> for PgValue<'c> {
|
||||
type Database = Postgres;
|
||||
|
||||
// The public type_info is used for type compatibility checks
|
||||
fn type_info(&self) -> Option<PgTypeInfo> {
|
||||
if let (Some(type_info), Some(_)) = (&self.type_info, &self.data) {
|
||||
Some(type_info.clone())
|
||||
// For TEXT encoding the type defined on the value is unreliable
|
||||
if matches!(self.data, Some(PgData::Binary(_))) {
|
||||
self.type_info.clone()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use sqlx::{postgres::PgQueryAs, Executor, Postgres};
|
||||
use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres};
|
||||
use sqlx_test::{new, test_type};
|
||||
use std::fmt::Debug;
|
||||
|
||||
@ -16,7 +16,7 @@ enum Weak {
|
||||
Three = 4,
|
||||
}
|
||||
|
||||
// "Strong" enums can map to TEXT (25) or a custom enum type
|
||||
// "Strong" enums can map to TEXT (25)
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
#[sqlx(rename = "text")]
|
||||
#[sqlx(rename_all = "lowercase")]
|
||||
@ -28,6 +28,16 @@ enum Strong {
|
||||
Three,
|
||||
}
|
||||
|
||||
// "Strong" enum can map to a custom type
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
#[sqlx(rename = "mood")]
|
||||
#[sqlx(rename_all = "lowercase")]
|
||||
enum Mood {
|
||||
Ok,
|
||||
Happy,
|
||||
Sad,
|
||||
}
|
||||
|
||||
// Records must map to a custom type
|
||||
// Note that all types are types in Postgres
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
@ -61,6 +71,100 @@ test_type!(strong_enum(
|
||||
"'four'::text" == Strong::Three
|
||||
));
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_enum_type() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
DO $$ BEGIN
|
||||
|
||||
CREATE TYPE mood AS ENUM ( 'ok', 'happy', 'sad' );
|
||||
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS people (
|
||||
id serial PRIMARY KEY,
|
||||
mood mood not null
|
||||
);
|
||||
|
||||
TRUNCATE people;
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Drop and re-acquire the connection
|
||||
conn.close().await?;
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
// Select from table test
|
||||
let (people_id,): (i32,) = sqlx::query_as(
|
||||
"
|
||||
INSERT INTO people (mood)
|
||||
VALUES ($1)
|
||||
RETURNING id
|
||||
",
|
||||
)
|
||||
.bind(Mood::Sad)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
// Drop and re-acquire the connection
|
||||
conn.close().await?;
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct PeopleRow {
|
||||
id: i32,
|
||||
mood: Mood,
|
||||
}
|
||||
|
||||
let rec: PeopleRow = sqlx::query_as(
|
||||
"
|
||||
SELECT id, mood FROM people WHERE id = $1
|
||||
",
|
||||
)
|
||||
.bind(people_id)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(rec.id, people_id);
|
||||
assert_eq!(rec.mood, Mood::Sad);
|
||||
|
||||
// Drop and re-acquire the connection
|
||||
conn.close().await?;
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id);
|
||||
dbg!(&stmt);
|
||||
let mut cursor = conn.fetch(&*stmt);
|
||||
|
||||
let row = cursor.next().await?.unwrap();
|
||||
let rec = PeopleRow::from_row(&row)?;
|
||||
|
||||
assert_eq!(rec.id, people_id);
|
||||
assert_eq!(rec.mood, Mood::Sad);
|
||||
|
||||
// Normal type equivalency test
|
||||
|
||||
let rec: (bool, Mood) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = 'happy'::mood, $1
|
||||
",
|
||||
)
|
||||
.bind(&Mood::Happy)
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert!(rec.0);
|
||||
assert_eq!(rec.1, Mood::Happy);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
|
||||
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
|
||||
async fn test_record_type() -> anyhow::Result<()> {
|
||||
|
Loading…
x
Reference in New Issue
Block a user