mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
refactor: add origin information to Column
This commit is contained in:
parent
e775d2a3eb
commit
bf90a477a1
@ -2,6 +2,7 @@ use crate::database::Database;
|
||||
use crate::error::Error;
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait Column: 'static + Send + Sync + Debug {
|
||||
type Database: Database<Column = Self>;
|
||||
@ -20,6 +21,59 @@ pub trait Column: 'static + Send + Sync + Debug {
|
||||
|
||||
/// Gets the type information for the column.
|
||||
fn type_info(&self) -> &<Self::Database as Database>::TypeInfo;
|
||||
|
||||
/// If this column comes from a table, return the table and original column name.
|
||||
///
|
||||
/// Returns [`ColumnOrigin::Expression`] if the column is the result of an expression
|
||||
/// or else the source table could not be determined.
|
||||
///
|
||||
/// Returns [`ColumnOrigin::Unknown`] if the database driver does not have that information,
|
||||
/// or has not overridden this method.
|
||||
// This method returns an owned value instead of a reference,
|
||||
// to give the implementor more flexibility.
|
||||
fn origin(&self) -> ColumnOrigin { ColumnOrigin::Unknown }
|
||||
}
|
||||
|
||||
/// A [`Column`] that originates from a table.
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct TableColumn {
|
||||
/// The name of the table (optionally schema-qualified) that the column comes from.
|
||||
pub table: Arc<str>,
|
||||
/// The original name of the column.
|
||||
pub name: Arc<str>,
|
||||
}
|
||||
|
||||
/// The possible statuses for our knowledge of the origin of a [`Column`].
|
||||
#[derive(Debug, Clone, Default)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum ColumnOrigin {
|
||||
/// The column is known to originate from a table.
|
||||
///
|
||||
/// Included is the table name and original column name.
|
||||
Table(TableColumn),
|
||||
/// The column originates from an expression, or else its origin could not be determined.
|
||||
Expression,
|
||||
/// The database driver does not know the column origin at this time.
|
||||
///
|
||||
/// This may happen if:
|
||||
/// * The connection is in the middle of executing a query,
|
||||
/// and cannot query the catalog to fetch this information.
|
||||
/// * The connection does not have access to the database catalog.
|
||||
/// * The implementation of [`Column`] did not override [`Column::origin()`].
|
||||
#[default]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl ColumnOrigin {
|
||||
/// Returns the true column origin, if known.
|
||||
pub fn table_column(&self) -> Option<&TableColumn> {
|
||||
if let Self::Table(table_column) = self {
|
||||
Some(table_column)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A type that can be used to index into a [`Row`] or [`Statement`].
|
||||
|
||||
@ -10,6 +10,9 @@ pub struct MySqlColumn {
|
||||
pub(crate) name: UStr,
|
||||
pub(crate) type_info: MySqlTypeInfo,
|
||||
|
||||
#[cfg_attr(feature = "offline", serde(default))]
|
||||
pub(crate) origin: ColumnOrigin,
|
||||
|
||||
#[cfg_attr(feature = "offline", serde(skip))]
|
||||
pub(crate) flags: Option<ColumnFlags>,
|
||||
}
|
||||
@ -28,4 +31,8 @@ impl Column for MySqlColumn {
|
||||
fn type_info(&self) -> &MySqlTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
|
||||
fn origin(&self) -> ColumnOrigin {
|
||||
self.origin.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@ -23,6 +23,7 @@ use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::{pin_mut, TryStreamExt};
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use sqlx_core::column::{ColumnOrigin, TableColumn};
|
||||
|
||||
impl MySqlConnection {
|
||||
async fn prepare_statement<'c>(
|
||||
@ -382,11 +383,30 @@ async fn recv_result_columns(
|
||||
fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
|
||||
// if the alias is empty, use the alias
|
||||
// only then use the name
|
||||
let column_name = def.name()?;
|
||||
|
||||
let name = match (def.name()?, def.alias()?) {
|
||||
(_, alias) if !alias.is_empty() => UStr::new(alias),
|
||||
(name, _) => UStr::new(name),
|
||||
};
|
||||
|
||||
let table = def.table()?;
|
||||
|
||||
let origin = if table.is_empty() {
|
||||
ColumnOrigin::Expression
|
||||
} else {
|
||||
let schema = def.schema()?;
|
||||
|
||||
ColumnOrigin::Table(TableColumn {
|
||||
table: if !schema.is_empty() {
|
||||
format!("{schema}.{table}").into()
|
||||
} else {
|
||||
table.into()
|
||||
},
|
||||
name: column_name.into(),
|
||||
})
|
||||
};
|
||||
|
||||
let type_info = MySqlTypeInfo::from_column(def);
|
||||
|
||||
Ok(MySqlColumn {
|
||||
@ -394,6 +414,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MyS
|
||||
type_info,
|
||||
ordinal,
|
||||
flags: Some(def.flags),
|
||||
origin,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
use std::str::from_utf8;
|
||||
use std::str;
|
||||
|
||||
use bitflags::bitflags;
|
||||
use bytes::{Buf, Bytes};
|
||||
@ -104,11 +104,9 @@ pub enum ColumnType {
|
||||
pub(crate) struct ColumnDefinition {
|
||||
#[allow(unused)]
|
||||
catalog: Bytes,
|
||||
#[allow(unused)]
|
||||
schema: Bytes,
|
||||
#[allow(unused)]
|
||||
table_alias: Bytes,
|
||||
#[allow(unused)]
|
||||
table: Bytes,
|
||||
alias: Bytes,
|
||||
name: Bytes,
|
||||
@ -125,12 +123,20 @@ impl ColumnDefinition {
|
||||
// NOTE: strings in-protocol are transmitted according to the client character set
|
||||
// as this is UTF-8, all these strings should be UTF-8
|
||||
|
||||
pub(crate) fn schema(&self) -> Result<&str, Error> {
|
||||
str::from_utf8(&self.schema).map_err(Error::protocol)
|
||||
}
|
||||
|
||||
pub(crate) fn table(&self) -> Result<&str, Error> {
|
||||
str::from_utf8(&self.table).map_err(Error::protocol)
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> Result<&str, Error> {
|
||||
from_utf8(&self.name).map_err(Error::protocol)
|
||||
str::from_utf8(&self.name).map_err(Error::protocol)
|
||||
}
|
||||
|
||||
pub(crate) fn alias(&self) -> Result<&str, Error> {
|
||||
from_utf8(&self.alias).map_err(Error::protocol)
|
||||
str::from_utf8(&self.alias).map_err(Error::protocol)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ use crate::ext::ustr::UStr;
|
||||
use crate::{PgTypeInfo, Postgres};
|
||||
|
||||
pub(crate) use sqlx_core::column::{Column, ColumnIndex};
|
||||
use sqlx_core::column::ColumnOrigin;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
@ -9,6 +10,10 @@ pub struct PgColumn {
|
||||
pub(crate) ordinal: usize,
|
||||
pub(crate) name: UStr,
|
||||
pub(crate) type_info: PgTypeInfo,
|
||||
|
||||
#[cfg_attr(feature = "offline", serde(default))]
|
||||
pub(crate) origin: ColumnOrigin,
|
||||
|
||||
#[cfg_attr(feature = "offline", serde(skip))]
|
||||
pub(crate) relation_id: Option<crate::types::Oid>,
|
||||
#[cfg_attr(feature = "offline", serde(skip))]
|
||||
@ -51,4 +56,8 @@ impl Column for PgColumn {
|
||||
fn type_info(&self) -> &PgTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
|
||||
fn origin(&self) -> ColumnOrigin {
|
||||
self.origin.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use std::collections::btree_map;
|
||||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::io::StatementId;
|
||||
@ -13,6 +14,9 @@ use crate::{PgColumn, PgConnection, PgTypeInfo};
|
||||
use smallvec::SmallVec;
|
||||
use sqlx_core::query_builder::QueryBuilder;
|
||||
use std::sync::Arc;
|
||||
use sqlx_core::column::{ColumnOrigin, TableColumn};
|
||||
use sqlx_core::hash_map;
|
||||
use crate::connection::TableColumns;
|
||||
|
||||
/// Describes the type of the `pg_type.typtype` column
|
||||
///
|
||||
@ -121,6 +125,12 @@ impl PgConnection {
|
||||
let type_info = self
|
||||
.maybe_fetch_type_info_by_oid(field.data_type_id, should_fetch)
|
||||
.await?;
|
||||
|
||||
let origin = if let (Some(relation_oid), Some(attribute_no)) = (field.relation_id, field.relation_attribute_no) {
|
||||
self.maybe_fetch_column_origin(relation_oid, attribute_no, should_fetch).await?
|
||||
} else {
|
||||
ColumnOrigin::Expression
|
||||
};
|
||||
|
||||
let column = PgColumn {
|
||||
ordinal: index,
|
||||
@ -128,6 +138,7 @@ impl PgConnection {
|
||||
type_info,
|
||||
relation_id: field.relation_id,
|
||||
relation_attribute_no: field.relation_attribute_no,
|
||||
origin,
|
||||
};
|
||||
|
||||
columns.push(column);
|
||||
@ -189,6 +200,54 @@ impl PgConnection {
|
||||
Ok(PgTypeInfo(PgType::DeclareWithOid(oid)))
|
||||
}
|
||||
}
|
||||
|
||||
async fn maybe_fetch_column_origin(
|
||||
&mut self,
|
||||
relation_id: Oid,
|
||||
attribute_no: i16,
|
||||
should_fetch: bool,
|
||||
) -> Result<ColumnOrigin, Error> {
|
||||
let mut table_columns = match self.cache_table_to_column_names.entry(relation_id) {
|
||||
hash_map::Entry::Occupied(table_columns) => {
|
||||
table_columns.into_mut()
|
||||
},
|
||||
hash_map::Entry::Vacant(vacant) => {
|
||||
if !should_fetch { return Ok(ColumnOrigin::Unknown); }
|
||||
|
||||
let table_name: String = query_scalar("SELECT $1::oid::regclass::text")
|
||||
.bind(relation_id)
|
||||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
vacant.insert(TableColumns {
|
||||
table_name: table_name.into(),
|
||||
columns: Default::default(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let column_name = match table_columns.columns.entry(attribute_no) {
|
||||
btree_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()),
|
||||
btree_map::Entry::Vacant(vacant) => {
|
||||
if !should_fetch { return Ok(ColumnOrigin::Unknown); }
|
||||
|
||||
let column_name: String = query_scalar(
|
||||
"SELECT attname FROM pg_attribute WHERE attrelid = $1 AND attnum = $2"
|
||||
)
|
||||
.bind(relation_id)
|
||||
.bind(attribute_no)
|
||||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
Arc::clone(vacant.insert(column_name.into()))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ColumnOrigin::Table(TableColumn {
|
||||
table: table_columns.table_name.clone(),
|
||||
name: column_name
|
||||
}))
|
||||
}
|
||||
|
||||
async fn fetch_type_by_oid(&mut self, oid: Oid) -> Result<PgTypeInfo, Error> {
|
||||
let (name, typ_type, category, relation_id, element, base_type): (
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -61,6 +62,7 @@ pub struct PgConnectionInner {
|
||||
cache_type_info: HashMap<Oid, PgTypeInfo>,
|
||||
cache_type_oid: HashMap<UStr, Oid>,
|
||||
cache_elem_type_to_array: HashMap<Oid, Oid>,
|
||||
cache_table_to_column_names: HashMap<Oid, TableColumns>,
|
||||
|
||||
// number of ReadyForQuery messages that we are currently expecting
|
||||
pub(crate) pending_ready_for_query_count: usize,
|
||||
@ -72,6 +74,12 @@ pub struct PgConnectionInner {
|
||||
log_settings: LogSettings,
|
||||
}
|
||||
|
||||
pub(crate) struct TableColumns {
|
||||
table_name: Arc<str>,
|
||||
/// Attribute number -> name.
|
||||
columns: BTreeMap<i16, Arc<str>>,
|
||||
}
|
||||
|
||||
impl PgConnection {
|
||||
/// the version number of the server in `libpq` format
|
||||
pub fn server_version_num(&self) -> Option<u32> {
|
||||
|
||||
@ -9,6 +9,9 @@ pub struct SqliteColumn {
|
||||
pub(crate) name: UStr,
|
||||
pub(crate) ordinal: usize,
|
||||
pub(crate) type_info: SqliteTypeInfo,
|
||||
|
||||
#[cfg_attr(feature = "offline", serde(default))]
|
||||
pub(crate) origin: ColumnOrigin
|
||||
}
|
||||
|
||||
impl Column for SqliteColumn {
|
||||
@ -25,4 +28,8 @@ impl Column for SqliteColumn {
|
||||
fn type_info(&self) -> &SqliteTypeInfo {
|
||||
&self.type_info
|
||||
}
|
||||
|
||||
fn origin(&self) -> ColumnOrigin {
|
||||
self.origin.clone()
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,6 +49,8 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
|
||||
|
||||
for col in 0..num {
|
||||
let name = stmt.handle.column_name(col).to_owned();
|
||||
|
||||
let origin = stmt.handle.column_origin(col);
|
||||
|
||||
let type_info = if let Some(ty) = stmt.handle.column_decltype(col) {
|
||||
ty
|
||||
@ -82,6 +84,7 @@ pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result<Descri
|
||||
name: name.into(),
|
||||
type_info,
|
||||
ordinal: col,
|
||||
origin,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ use std::ptr;
|
||||
use std::ptr::NonNull;
|
||||
use std::slice::from_raw_parts;
|
||||
use std::str::{from_utf8, from_utf8_unchecked};
|
||||
|
||||
use std::sync::Arc;
|
||||
use libsqlite3_sys::{
|
||||
sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64,
|
||||
sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name,
|
||||
@ -19,7 +19,7 @@ use libsqlite3_sys::{
|
||||
sqlite3_value, SQLITE_DONE, SQLITE_LOCKED_SHAREDCACHE, SQLITE_MISUSE, SQLITE_OK, SQLITE_ROW,
|
||||
SQLITE_TRANSIENT, SQLITE_UTF8,
|
||||
};
|
||||
|
||||
use sqlx_core::column::{ColumnOrigin, TableColumn};
|
||||
use crate::error::{BoxDynError, Error};
|
||||
use crate::type_info::DataType;
|
||||
use crate::{SqliteError, SqliteTypeInfo};
|
||||
@ -34,6 +34,9 @@ pub(crate) struct StatementHandle(NonNull<sqlite3_stmt>);
|
||||
|
||||
unsafe impl Send for StatementHandle {}
|
||||
|
||||
// Most of the getters below allocate internally, and unsynchronized access is undefined.
|
||||
// unsafe impl !Sync for StatementHandle {}
|
||||
|
||||
macro_rules! expect_ret_valid {
|
||||
($fn_name:ident($($args:tt)*)) => {{
|
||||
let val = $fn_name($($args)*);
|
||||
@ -110,6 +113,64 @@ impl StatementHandle {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn column_origin(&self, index: usize) -> ColumnOrigin {
|
||||
if let Some((table, name)) =
|
||||
self.column_table_name(index).zip(self.column_origin_name(index))
|
||||
{
|
||||
let table: Arc<str> = self
|
||||
.column_db_name(index)
|
||||
.filter(|&db| db != "main")
|
||||
.map_or_else(
|
||||
|| table.into(),
|
||||
// TODO: check that SQLite returns the names properly quoted if necessary
|
||||
|db| format!("{db}.{table}").into(),
|
||||
);
|
||||
|
||||
ColumnOrigin::Table(TableColumn {
|
||||
table,
|
||||
name: name.into()
|
||||
})
|
||||
} else {
|
||||
ColumnOrigin::Expression
|
||||
}
|
||||
}
|
||||
|
||||
fn column_db_name(&self, index: usize) -> Option<&str> {
|
||||
unsafe {
|
||||
let db_name = sqlite3_column_database_name(self.0.as_ptr(), check_col_idx!(index));
|
||||
|
||||
if !db_name.is_null() {
|
||||
Some(from_utf8_unchecked(CStr::from_ptr(db_name).to_bytes()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn column_table_name(&self, index: usize) -> Option<&str> {
|
||||
unsafe {
|
||||
let table_name = sqlite3_column_table_name(self.0.as_ptr(), check_col_idx!(index));
|
||||
|
||||
if !table_name.is_null() {
|
||||
Some(from_utf8_unchecked(CStr::from_ptr(table_name).to_bytes()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn column_origin_name(&self, index: usize) -> Option<&str> {
|
||||
unsafe {
|
||||
let origin_name = sqlite3_column_origin_name(self.0.as_ptr(), check_col_idx!(index));
|
||||
|
||||
if !origin_name.is_null() {
|
||||
Some(from_utf8_unchecked(CStr::from_ptr(origin_name).to_bytes()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo {
|
||||
SqliteTypeInfo(DataType::from_code(self.column_type(index)))
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ pub use sqlx_core::acquire::Acquire;
|
||||
pub use sqlx_core::arguments::{Arguments, IntoArguments};
|
||||
pub use sqlx_core::column::Column;
|
||||
pub use sqlx_core::column::ColumnIndex;
|
||||
pub use sqlx_core::column::ColumnOrigin;
|
||||
pub use sqlx_core::connection::{ConnectOptions, Connection};
|
||||
pub use sqlx_core::database::{self, Database};
|
||||
pub use sqlx_core::describe::Describe;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user