refactor: add origin information to Column

This commit is contained in:
Austin Bonander 2024-09-18 01:55:59 -07:00
parent e775d2a3eb
commit bf90a477a1
11 changed files with 243 additions and 7 deletions

View File

@ -2,6 +2,7 @@ use crate::database::Database;
use crate::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
pub trait Column: 'static + Send + Sync + Debug {
type Database: Database<Column = Self>;
@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug {
/// Gets the type information for the column.
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
/// Reports where this column's data comes from, when the driver can tell.
///
/// * [`ColumnOrigin::Table`] carries the source table and the original column name.
/// * [`ColumnOrigin::Expression`] means the column is the result of an expression,
///   or else the source table could not be determined.
/// * [`ColumnOrigin::Unknown`] means the database driver does not have that information,
///   or has not overridden this method.
// An owned value is returned here instead of a reference on purpose:
// it gives the implementor more flexibility.
fn origin(&self) -> ColumnOrigin {
    ColumnOrigin::Unknown
}
}
/// A [`Column`] that originates from a table.
///
/// Both names are held as `Arc<str>`, so cloning a `TableColumn` shares the
/// underlying strings rather than copying them.
///
/// With the `offline` feature enabled this type also derives the `serde`
/// `Serialize` and `Deserialize` traits.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct TableColumn {
/// The name of the table (optionally schema-qualified) that the column comes from.
pub table: Arc<str>,
/// The original name of the column.
pub name: Arc<str>,
}
/// The possible statuses for our knowledge of the origin of a [`Column`].
///
/// The [`Default`] value is [`ColumnOrigin::Unknown`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub enum ColumnOrigin {
/// The column is known to originate from a table.
///
/// Included is the table name and original column name.
Table(TableColumn),
/// The column originates from an expression, or else its origin could not be determined.
Expression,
/// The database driver does not know the column origin at this time.
///
/// This may happen if:
/// * The connection is in the middle of executing a query,
///   and cannot query the catalog to fetch this information.
/// * The connection does not have access to the database catalog.
/// * The implementation of [`Column`] did not override [`Column::origin()`].
#[default]
Unknown,
}
impl ColumnOrigin {
    /// Borrows the table and column name when the origin is [`ColumnOrigin::Table`].
    ///
    /// Every other variant yields `None`.
    pub fn table_column(&self) -> Option<&TableColumn> {
        match self {
            Self::Table(table_column) => Some(table_column),
            Self::Expression | Self::Unknown => None,
        }
    }
}
/// A type that can be used to index into a [`Row`] or [`Statement`].

View File

@ -10,6 +10,9 @@ pub struct MySqlColumn {
pub(crate) name: UStr,
pub(crate) type_info: MySqlTypeInfo,
#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin,
#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) flags: Option<ColumnFlags>,
}
@ -28,4 +31,8 @@ impl Column for MySqlColumn {
fn type_info(&self) -> &MySqlTypeInfo {
&self.type_info
}
fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}

View File

@ -23,6 +23,7 @@ use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use std::{borrow::Cow, sync::Arc};
use sqlx_core::column::{ColumnOrigin, TableColumn};
impl MySqlConnection {
async fn prepare_statement<'c>(
@ -382,11 +383,30 @@ async fn recv_result_columns(
fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
// if the alias is non-empty, use the alias;
// only otherwise fall back to the name
let column_name = def.name()?;
let name = match (def.name()?, def.alias()?) {
(_, alias) if !alias.is_empty() => UStr::new(alias),
(name, _) => UStr::new(name),
};
let table = def.table()?;
let origin = if table.is_empty() {
ColumnOrigin::Expression
} else {
let schema = def.schema()?;
ColumnOrigin::Table(TableColumn {
table: if !schema.is_empty() {
format!("{schema}.{table}").into()
} else {
table.into()
},
name: column_name.into(),
})
};
let type_info = MySqlTypeInfo::from_column(def);
Ok(MySqlColumn {
@ -394,6 +414,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MyS
type_info,
ordinal,
flags: Some(def.flags),
origin,
})
}

View File

@ -1,4 +1,4 @@
use std::str::from_utf8;
use std::str;
use bitflags::bitflags;
use bytes::{Buf, Bytes};
@ -104,11 +104,9 @@ pub enum ColumnType {
pub(crate) struct ColumnDefinition {
#[allow(unused)]
catalog: Bytes,
#[allow(unused)]
schema: Bytes,
#[allow(unused)]
table_alias: Bytes,
#[allow(unused)]
table: Bytes,
alias: Bytes,
name: Bytes,
@ -125,12 +123,20 @@ impl ColumnDefinition {
// NOTE: strings in-protocol are transmitted according to the client character set
// as this is UTF-8, all these strings should be UTF-8
pub(crate) fn schema(&self) -> Result<&str, Error> {
str::from_utf8(&self.schema).map_err(Error::protocol)
}
pub(crate) fn table(&self) -> Result<&str, Error> {
str::from_utf8(&self.table).map_err(Error::protocol)
}
pub(crate) fn name(&self) -> Result<&str, Error> {
from_utf8(&self.name).map_err(Error::protocol)
str::from_utf8(&self.name).map_err(Error::protocol)
}
pub(crate) fn alias(&self) -> Result<&str, Error> {
from_utf8(&self.alias).map_err(Error::protocol)
str::from_utf8(&self.alias).map_err(Error::protocol)
}
}

View File

@ -2,6 +2,7 @@ use crate::ext::ustr::UStr;
use crate::{PgTypeInfo, Postgres};
pub(crate) use sqlx_core::column::{Column, ColumnIndex};
use sqlx_core::column::ColumnOrigin;
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
@ -9,6 +10,10 @@ pub struct PgColumn {
pub(crate) ordinal: usize,
pub(crate) name: UStr,
pub(crate) type_info: PgTypeInfo,
#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin,
#[cfg_attr(feature = "offline", serde(skip))]
pub(crate) relation_id: Option<crate::types::Oid>,
#[cfg_attr(feature = "offline", serde(skip))]
@ -51,4 +56,8 @@ impl Column for PgColumn {
fn type_info(&self) -> &PgTypeInfo {
&self.type_info
}
fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}

View File

@ -1,3 +1,4 @@
use std::collections::btree_map;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
@ -13,6 +14,9 @@ use crate::{PgColumn, PgConnection, PgTypeInfo};
use smallvec::SmallVec;
use sqlx_core::query_builder::QueryBuilder;
use std::sync::Arc;
use sqlx_core::column::{ColumnOrigin, TableColumn};
use sqlx_core::hash_map;
use crate::connection::TableColumns;
/// Describes the type of the `pg_type.typtype` column
///
@ -121,6 +125,12 @@ impl PgConnection {
let type_info = self
.maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
.await?;
let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) {
self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await?
} else {
ColumnOrigin::Expression
};
let column = PgColumn {
ordinal: index,
@ -128,6 +138,7 @@ impl PgConnection {
type_info,
relation_id: field.relation_id,
relation_attribute_no: field.relation_attribute_no,
origin,
};
columns.push(column);
@ -189,6 +200,54 @@ impl PgConnection {
Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
}
}
/// Resolves the table and column name behind a result column, consulting the
/// connection-local `cache_table_to_column_names` before querying the server.
///
/// * `relation_id` - OID of the relation the column belongs to.
/// * `attribute_no` - attribute number of the column within that relation.
/// * `should_fetch` - whether the server may be queried on a cache miss.
///
/// Returns [`ColumnOrigin::Table`] when both names are known,
/// and [`ColumnOrigin::Unknown`] when the answer is not cached and
/// `should_fetch` is `false`, or when no matching attribute row exists.
///
/// # Errors
///
/// Propagates any error raised by the catalog queries.
// NOTE(review): the queries below are executed on this same connection; confirm
// that describing them cannot re-enter this method.
async fn maybe_fetch_column_origin(
    &mut self,
    relation_id: Oid,
    attribute_no: i16,
    should_fetch: bool,
) -> Result<ColumnOrigin, Error> {
    // Copy what we need out of the cache up front. Keeping a map entry alive across
    // the `.await`s below would hold `self` mutably borrowed while the queries also
    // need `&mut *self`.
    let cached_table = self.cache_table_to_column_names.get(&relation_id);
    let cached_table_name = cached_table.map(|cached| Arc::clone(&cached.table_name));
    let cached_column_name =
        cached_table.and_then(|cached| cached.columns.get(&attribute_no).cloned());

    let table_name: Arc<str> = match (cached_table_name, cached_column_name) {
        // Fully cached: no round-trip needed.
        (Some(table), Some(name)) => {
            return Ok(ColumnOrigin::Table(TableColumn { table, name }));
        }
        // Something is missing and we are not allowed to ask the server.
        _ if !should_fetch => return Ok(ColumnOrigin::Unknown),
        (Some(table), _) => table,
        (None, _) => {
            let fetched: String = query_scalar("SELECT $1::oid::regclass::text")
                .bind(relation_id)
                .fetch_one(&mut *self)
                .await?;

            fetched.into()
        }
    };

    let fetched_column: Option<String> = query_scalar(
        "SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2",
    )
    .bind(relation_id)
    .bind(attribute_no)
    .fetch_optional(&mut *self)
    .await?;

    let column_name: Arc<str> = match fetched_column {
        Some(name) => name.into(),
        // No such attribute row: report the origin as unknown
        // instead of failing the whole operation.
        None => return Ok(ColumnOrigin::Unknown),
    };

    // Only now, with no query in flight, touch the cache mutably.
    self.cache_table_to_column_names
        .entry(relation_id)
        .or_insert_with(|| TableColumns {
            table_name: Arc::clone(&table_name),
            columns: Default::default(),
        })
        .columns
        .insert(attribute_no, Arc::clone(&column_name));

    Ok(ColumnOrigin::Table(TableColumn {
        table: table_name,
        name: column_name,
    }))
}
async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
let (name, typ_type, category, relation_id, element, base_type): (

View File

@ -1,3 +1,4 @@
use std::collections::BTreeMap;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
@ -61,6 +62,7 @@ pub struct PgConnectionInner {
cache_type_info: HashMap<Oid, PgTypeInfo>,
cache_type_oid: HashMap<UStr, Oid>,
cache_elem_type_to_array: HashMap<Oid, Oid>,
cache_table_to_column_names: HashMap<Oid, TableColumns>,
// number of ReadyForQuery messages that we are currently expecting
pub(crate) pending_ready_for_query_count: usize,
@ -72,6 +74,12 @@ pub struct PgConnectionInner {
log_settings: LogSettings,
}
/// Cached naming information for one table: the table's name plus the names of
/// its columns looked up so far.
///
/// Stored as the value type of the `cache_table_to_column_names` map.
pub(crate) struct TableColumns {
/// Name of the table.
table_name: Arc<str>,
/// Attribute number -> name.
columns: BTreeMap<i16, Arc<str>>,
}
impl PgConnection {
/// the version number of the server in `libpq` format
pub fn server_version_num(&self) -> Option<u32> {

View File

@ -9,6 +9,9 @@ pub struct SqliteColumn {
pub(crate) name: UStr,
pub(crate) ordinal: usize,
pub(crate) type_info: SqliteTypeInfo,
#[cfg_attr(feature = "offline", serde(default))]
pub(crate) origin: ColumnOrigin
}
impl Column for SqliteColumn {
@ -25,4 +28,8 @@ impl Column for SqliteColumn {
fn type_info(&self) -> &SqliteTypeInfo {
&self.type_info
}
fn origin(&self) -> ColumnOrigin {
self.origin.clone()
}
}

View File

@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
for col in 0..num {
let name = stmt.handle.column_name(col).to_owned();
let origin = stmt.handle.column_origin(col);
let type_info = if let Some(ty) = stmt.handle.column_decltype(col) {
ty
@ -82,6 +84,7 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
name: name.into(),
type_info,
ordinal: col,
origin,
});
}
}

View File

@ -6,7 +6,7 @@ use std::ptr;
use std::ptr::NonNull;
use std::slice::from_raw_parts;
use std::str::{from_utf8, from_utf8_unchecked};
use std::sync::Arc;
use libsqlite3_sys::{
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
@ -19,7 +19,7 @@ use libsqlite3_sys::{
sqlite3_value, SQLITE_DONE, SQLITE_LOCKED_SHAREDCACHE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW,
SQLITE_TRANSIENT, SQLITE_UTF8,
};
use sqlx_core::column::{ColumnOrigin, TableColumn};
use crate::error::{BoxDynError, Error};
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo};
@ -34,6 +34,9 @@ pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);
unsafe impl Send for StatementHandle {}
// Most of the getters below allocate internally, and unsynchronized access is undefined.
// unsafe impl !Sync for StatementHandle {}
macro_rules! expect_ret_valid {
($fn_name:ident($($args:tt)*)) => {{
let val = $fn_name($($args)*);
@ -110,6 +113,64 @@ impl StatementHandle {
}
}
/// Builds the [`ColumnOrigin`] for the result column at `index`.
///
/// Yields [`ColumnOrigin::Table`] when both a table name and an origin column name
/// are available, and [`ColumnOrigin::Expression`] otherwise. The table name is
/// prefixed with the database name unless that name is missing or is `"main"`.
pub(crate) fn column_origin(&self, index: usize) -> ColumnOrigin {
    let table_name = self.column_table_name(index);
    let origin_name = self.column_origin_name(index);

    let (table, name) = match (table_name, origin_name) {
        (Some(table), Some(name)) => (table, name),
        _ => return ColumnOrigin::Expression,
    };

    // TODO: check that SQLite returns the names properly quoted if necessary
    let qualified: Arc<str> = match self.column_db_name(index) {
        Some(db) if db != "main" => format!("{db}.{table}").into(),
        _ => table.into(),
    };

    ColumnOrigin::Table(TableColumn {
        table: qualified,
        name: name.into(),
    })
}
/// Asks `sqlite3_column_database_name` for the result column at `index`.
///
/// Returns `None` when SQLite hands back a null pointer.
fn column_db_name(&self, index: usize) -> Option<&str> {
    unsafe {
        let raw = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index));

        if raw.is_null() {
            return None;
        }

        // NOTE(review): the bytes are not re-validated; this relies on SQLite
        // returning a NUL-terminated UTF-8 string here.
        Some(from_utf8_unchecked(CStr::from_ptr(raw).to_bytes()))
    }
}
/// Asks `sqlite3_column_table_name` for the result column at `index`.
///
/// Returns `None` when SQLite hands back a null pointer.
fn column_table_name(&self, index: usize) -> Option<&str> {
    unsafe {
        let raw = sqlite3_column_table_name(self.0.as_ptr(), check_col_idx!(index));

        if raw.is_null() {
            return None;
        }

        // NOTE(review): the bytes are not re-validated; this relies on SQLite
        // returning a NUL-terminated UTF-8 string here.
        Some(from_utf8_unchecked(CStr::from_ptr(raw).to_bytes()))
    }
}
/// Asks `sqlite3_column_origin_name` for the result column at `index`.
///
/// Returns `None` when SQLite hands back a null pointer.
fn column_origin_name(&self, index: usize) -> Option<&str> {
    unsafe {
        let raw = sqlite3_column_origin_name(self.0.as_ptr(), check_col_idx!(index));

        if raw.is_null() {
            return None;
        }

        // NOTE(review): the bytes are not re-validated; this relies on SQLite
        // returning a NUL-terminated UTF-8 string here.
        Some(from_utf8_unchecked(CStr::from_ptr(raw).to_bytes()))
    }
}
/// Wraps the type code reported by `column_type` for `index` in a [`SqliteTypeInfo`].
pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo {
    let type_code = self.column_type(index);
    SqliteTypeInfo(DataType::from_code(type_code))
}

View File

@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire;
pub use sqlx_core::arguments::{Arguments, IntoArguments};
pub use sqlx_core::column::Column;
pub use sqlx_core::column::ColumnIndex;
pub use sqlx_core::column::ColumnOrigin;
pub use sqlx_core::connection::{ConnectOptions, Connection};
pub use sqlx_core::database::{self, Database};
pub use sqlx_core::describe::Describe;