diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index 9f45819e..74833757 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -2,6 +2,7 @@ use crate::database::Database; use crate::error::Error; use std::fmt::Debug; +use std::sync::Arc; pub trait Column: 'static + Send + Sync + Debug { type Database: Database; @@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug { /// Gets the type information for the column. fn type_info(&self) -> &::TypeInfo; + + /// If this column comes from a table, return the table and original column name. + /// + /// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression + /// or else the source table could not be determined. + /// + /// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information, + /// or has not overridden this method. + // This method returns an owned value instead of a reference, + // to give the implementor more flexibility. + fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown } +} + +/// A [`Column`] that originates from a table. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct TableColumn { + /// The name of the table (optionally schema-qualified) that the column comes from. + pub table: Arc, + /// The original name of the column. + pub name: Arc, +} + +/// The possible statuses for our knowledge of the origin of a [`Column`]. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub enum ColumnOrigin { + /// The column is known to originate from a table. + /// + /// Included is the table name and original column name. + Table(TableColumn), + /// The column originates from an expression, or else its origin could not be determined. + Expression, + /// The database driver does not know the column origin at this time. + /// + /// This may happen if: + /// * The connection is in the middle of executing a query, + /// and cannot query the catalog to fetch this information. + /// * The connection does not have access to the database catalog. + /// * The implementation of [`Column`] did not override [`Column::origin()`]. + #[default] + Unknown, +} + +impl ColumnOrigin { + /// Returns the true column origin, if known. + pub fn table_column(&self) -> Option<&TableColumn> { + if let Self::Table(table_column) = self { + Some(table_column) + } else { + None + } + } } /// A type that can be used to index into a [`Row`] or [`Statement`]. diff --git a/sqlx-mysql/src/column.rs b/sqlx-mysql/src/column.rs index 1bb841b9..457cf991 100644 --- a/sqlx-mysql/src/column.rs +++ b/sqlx-mysql/src/column.rs @@ -10,6 +10,9 @@ pub struct MySqlColumn { pub(crate) name: UStr, pub(crate) type_info: MySqlTypeInfo, + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) flags: Option, } @@ -28,4 +31,8 @@ impl Column for MySqlColumn { fn type_info(&self) -> &MySqlTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index 07c7979b..6baad5cc 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -23,6 +23,7 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::{borrow::Cow, sync::Arc}; +use sqlx_core::column::{ColumnOrigin, TableColumn}; impl MySqlConnection { async fn prepare_statement<'c>( @@ -382,11 +383,30 @@ async fn recv_result_columns( fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result { // if the alias is empty, use the alias // only then use the name + let column_name = def.name()?; + let name = match (def.name()?, def.alias()?) { (_, alias) if !alias.is_empty() => UStr::new(alias), (name, _) => UStr::new(name), }; + let table = def.table()?; + + let origin = if table.is_empty() { + ColumnOrigin::Expression + } else { + let schema = def.schema()?; + + ColumnOrigin::Table(TableColumn { + table: if !schema.is_empty() { + format!("{schema}.{table}").into() + } else { + table.into() + }, + name: column_name.into(), + }) + }; + let type_info = MySqlTypeInfo::from_column(def); Ok(MySqlColumn { @@ -394,6 +414,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result Result<&str, Error> { + str::from_utf8(&self.schema).map_err(Error::protocol) + } + + pub(crate) fn table(&self) -> Result<&str, Error> { + str::from_utf8(&self.table).map_err(Error::protocol) + } + pub(crate) fn name(&self) -> Result<&str, Error> { - from_utf8(&self.name).map_err(Error::protocol) + str::from_utf8(&self.name).map_err(Error::protocol) } pub(crate) fn alias(&self) -> Result<&str, Error> { - from_utf8(&self.alias).map_err(Error::protocol) + str::from_utf8(&self.alias).map_err(Error::protocol) } } diff --git a/sqlx-postgres/src/column.rs b/sqlx-postgres/src/column.rs index a838c27b..bd08e27d 100644 --- a/sqlx-postgres/src/column.rs +++ b/sqlx-postgres/src/column.rs @@ -2,6 +2,7 @@ use crate::ext::ustr::UStr; use crate::{PgTypeInfo, Postgres}; pub(crate) use sqlx_core::column::{Column, ColumnIndex}; +use sqlx_core::column::ColumnOrigin; #[derive(Debug, Clone)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] @@ -9,6 +10,10 @@ pub struct PgColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: PgTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin, + #[cfg_attr(feature = "offline", serde(skip))] pub(crate) relation_id: Option, #[cfg_attr(feature = "offline", serde(skip))] @@ -51,4 +56,8 @@ impl Column for PgColumn { fn type_info(&self) -> &PgTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index a27578c5..53affe5d 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,3 +1,4 @@ +use std::collections::btree_map; use crate::error::Error; use crate::ext::ustr::UStr; use crate::io::StatementId; @@ -13,6 +14,9 @@ use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; use sqlx_core::query_builder::QueryBuilder; use std::sync::Arc; +use sqlx_core::column::{ColumnOrigin, TableColumn}; +use sqlx_core::hash_map; +use crate::connection::TableColumns; /// Describes the type of the `pg_type.typtype` column /// @@ -121,6 +125,12 @@ impl PgConnection { let type_info = self .maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch) .await?; + + let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) { + self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await? + } else { + ColumnOrigin::Expression + }; let column = PgColumn { ordinal: index, @@ -128,6 +138,7 @@ impl PgConnection { type_info, relation_id: field.relation_id, relation_attribute_no: field.relation_attribute_no, + origin, }; columns.push(column); @@ -189,6 +200,54 @@ impl PgConnection { Ok(PgTypeInfo(PgType::DeclareWithOid(oid))) } } + + async fn maybe_fetch_column_origin( + &mut self, + relation_id: Oid, + attribute_no: i16, + should_fetch: bool, + ) -> Result { + let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) { + hash_map::Entry::Occupied(table_columns) => { + table_columns.into_mut() + }, + hash_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let table_name: String = query_scalar("SELECT $1::oid::regclass::text") + .bind(relation_id) + .fetch_one(&mut *self) + .await?; + + vacant.insert(TableColumns { + table_name: table_name.into(), + columns: Default::default(), + }) + } + }; + + let column_name = match table_columns.columns.entry(attribute_no) { + btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()), + btree_map::Entry::Vacant(vacant) => { + if !should_fetch { return Ok(ColumnOrigin::Unknown); } + + let column_name: String = query_scalar( + "SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2" + ) + .bind(relation_id) + .bind(attribute_no) + .fetch_one(&mut *self) + .await?; + + Arc::clone(vacant.insert(column_name.into())) + } + }; + + Ok(ColumnOrigin::Table(TableColumn { + table: table_columns.table_name.clone(), + name: column_name + })) + } async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result { let (name, typ_type, category, relation_id, element, base_type): ( diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index c139f8e5..3cb9ecaf 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -61,6 +62,7 @@ pub struct PgConnectionInner { cache_type_info: HashMap, cache_type_oid: HashMap, cache_elem_type_to_array: HashMap, + cache_table_to_column_names: HashMap, // number of ReadyForQuery messages that we are currently expecting pub(crate) pending_ready_for_query_count: usize, @@ -72,6 +74,12 @@ pub struct PgConnectionInner { log_settings: LogSettings, } +pub(crate) struct TableColumns { + table_name: Arc, + /// Attribute number -> name. + columns: BTreeMap>, +} + impl PgConnection { /// the version number of the server in `libpq` format pub fn server_version_num(&self) -> Option { diff --git a/sqlx-sqlite/src/column.rs b/sqlx-sqlite/src/column.rs index 00b3bc36..390f3687 100644 --- a/sqlx-sqlite/src/column.rs +++ b/sqlx-sqlite/src/column.rs @@ -9,6 +9,9 @@ pub struct SqliteColumn { pub(crate) name: UStr, pub(crate) ordinal: usize, pub(crate) type_info: SqliteTypeInfo, + + #[cfg_attr(feature = "offline", serde(default))] + pub(crate) origin: ColumnOrigin } impl Column for SqliteColumn { @@ -25,4 +28,8 @@ impl Column for SqliteColumn { fn type_info(&self) -> &SqliteTypeInfo { &self.type_info } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 0f4da33c..9ba9f8c3 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result Result); unsafe impl Send for StatementHandle {} +// Most of the getters below allocate internally, and unsynchronized access is undefined. +// unsafe impl !Sync for StatementHandle {} + macro_rules! expect_ret_valid { ($fn_name:ident($($args:tt)*)) => {{ let val = $fn_name($($args)*); @@ -110,6 +113,64 @@ impl StatementHandle { } } + pub(crate) fn column_origin(&self, index: usize) -> ColumnOrigin { + if let Some((table, name)) = + self.column_table_name(index).zip(self.column_origin_name(index)) + { + let table: Arc = self + .column_db_name(index) + .filter(|&db| db != "main") + .map_or_else( + || table.into(), + // TODO: check that SQLite returns the names properly quoted if necessary + |db| format!("{db}.{table}").into(), + ); + + ColumnOrigin::Table(TableColumn { + table, + name: name.into() + }) + } else { + ColumnOrigin::Expression + } + } + + fn column_db_name(&self, index: usize) -> Option<&str> { + unsafe { + let db_name = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index)); + + if !db_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(db_name).to_bytes())) + } else { + None + } + } + } + + fn column_table_name(&self, index: usize) -> Option<&str> { + unsafe { + let table_name = sqlite3_column_table_name(self.0.as_ptr(), check_col_idx!(index)); + + if !table_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(table_name).to_bytes())) + } else { + None + } + } + } + + fn column_origin_name(&self, index: usize) -> Option<&str> { + unsafe { + let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), check_col_idx!(index)); + + if !origin_name.is_null() { + Some(from_utf8_unchecked(CStr::from_ptr(origin_name).to_bytes())) + } else { + None + } + } + } + pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo { SqliteTypeInfo(DataType::from_code(self.column_type(index))) } diff --git a/src/lib.rs b/src/lib.rs index 19142f66..a357753b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire; pub use sqlx_core::arguments::{Arguments, IntoArguments}; pub use sqlx_core::column::Column; pub use sqlx_core::column::ColumnIndex; +pub use sqlx_core::column::ColumnOrigin; pub use sqlx_core::connection::{ConnectOptions, Connection}; pub use sqlx_core::database::{self, Database}; pub use sqlx_core::describe::Describe;