diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index 45be0feb..6b8ee0ce 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -67,7 +67,7 @@ const OP_SEEK_GE: &str = "SeekGE"; const OP_SEEK_GT: &str = "SeekGT"; const OP_SEEK_LE: &str = "SeekLE"; const OP_SEEK_LT: &str = "SeekLT"; -const OP_SEEK_ROW_ID: &str = "SeekRowId"; +const OP_SEEK_ROW_ID: &str = "SeekRowid"; const OP_SEEK_SCAN: &str = "SeekScan"; const OP_SEQUENCE: &str = "Sequence"; const OP_SEQUENCE_TEST: &str = "SequenceTest"; @@ -85,6 +85,7 @@ const OP_COLUMN: &str = "Column"; const OP_MAKE_RECORD: &str = "MakeRecord"; const OP_INSERT: &str = "Insert"; const OP_IDX_INSERT: &str = "IdxInsert"; +const OP_OPEN_DUP: &str = "OpenDup"; const OP_OPEN_PSEUDO: &str = "OpenPseudo"; const OP_OPEN_READ: &str = "OpenRead"; const OP_OPEN_WRITE: &str = "OpenWrite"; @@ -200,34 +201,35 @@ impl RegDataType { } } +impl Default for RegDataType { + fn default() -> Self { + Self::Single(ColumnType::default()) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +struct TableDataType { + cols: IntMap, + is_empty: Option, +} + #[derive(Debug, Clone, Eq, PartialEq, Hash)] enum CursorDataType { - Normal { - cols: IntMap, - - is_empty: Option, - }, + Normal(i64), Pseudo(i64), } impl CursorDataType { - fn from_intmap(record: &IntMap, is_empty: Option) -> Self { - Self::Normal { - cols: record.clone(), - is_empty, - } - } - - fn from_dense_record(record: &Vec, is_empty: Option) -> Self { - Self::Normal { - cols: IntMap::from_dense_record(record), - is_empty, - } - } - - fn map_to_intmap(&self, registers: &IntMap) -> IntMap { + fn columns( + &self, + tables: &IntMap, + registers: &IntMap, + ) -> IntMap { match self { - Self::Normal { cols, .. } => cols.clone(), + Self::Normal(i) => match tables.get(i) { + Some(tab) => tab.cols.clone(), + None => IntMap::new(), + }, Self::Pseudo(i) => match registers.get(i) { Some(RegDataType::Single(ColumnType::Record(r))) => r.clone(), _ => IntMap::new(), @@ -235,9 +237,71 @@ impl CursorDataType { } } - fn is_empty(&self) -> Option { + fn columns_ref<'s, 'r, 'o>( + &'s self, + tables: &'r IntMap, + registers: &'r IntMap, + ) -> Option<&'o IntMap> + where + 's: 'o, + 'r: 'o, + { match self { - Self::Normal { is_empty, .. } => *is_empty, + Self::Normal(i) => match tables.get(i) { + Some(tab) => Some(&tab.cols), + None => None, + }, + Self::Pseudo(i) => match registers.get(i) { + Some(RegDataType::Single(ColumnType::Record(r))) => Some(r), + _ => None, + }, + } + } + + fn columns_mut<'s, 'r, 'o>( + &'s self, + tables: &'r mut IntMap, + registers: &'r mut IntMap, + ) -> Option<&'o mut IntMap> + where + 's: 'o, + 'r: 'o, + { + match self { + Self::Normal(i) => match tables.get_mut(i) { + Some(tab) => Some(&mut tab.cols), + None => None, + }, + Self::Pseudo(i) => match registers.get_mut(i) { + Some(RegDataType::Single(ColumnType::Record(r))) => Some(r), + _ => None, + }, + } + } + + fn table_mut<'s, 'r, 'o>( + &'s self, + tables: &'r mut IntMap, + ) -> Option<&'o mut TableDataType> + where + 's: 'o, + 'r: 'o, + { + match self { + Self::Normal(i) => match tables.get_mut(i) { + Some(tab) => Some(tab), + None => None, + }, + _ => None, + } + } + + fn is_empty(&self, tables: &IntMap) -> Option { + match self { + Self::Normal(i) => match tables.get(i) { + Some(tab) => tab.is_empty, + None => Some(true), + }, Self::Pseudo(_) => Some(false), //pseudo cursors have exactly one row } } @@ -332,6 +396,8 @@ struct MemoryState { pub r: IntMap, // Rows that pointers point to pub p: IntMap, + // Table definitions pointed to by pointers + pub t: IntMap, } struct BranchList { @@ -380,6 +446,7 @@ pub(super) fn explain( mem: MemoryState { program_i: 0, r: IntMap::new(), + t: IntMap::new(), p: IntMap::new(), }, }); @@ -602,21 +669,21 @@ pub(super) fn explain( } if let Some(cursor) = state.mem.p.get(&p1) { - if matches!(cursor.is_empty(), None | Some(true)) { + if matches!(cursor.is_empty(&state.mem.t), None | Some(true)) { //only take this branch if the cursor is empty let mut branch_state = state.clone(); branch_state.mem.program_i = p2 as usize; - if let Some(CursorDataType::Normal { is_empty, .. }) = - branch_state.mem.p.get_mut(&p1) - { - *is_empty = Some(true); + if let Some(cur) = branch_state.mem.p.get(&p1) { + if let Some(tab) = cur.table_mut(&mut branch_state.mem.t) { + tab.is_empty = Some(true); + } } states.push(branch_state); } - if matches!(cursor.is_empty(), None | Some(false)) { + if matches!(cursor.is_empty(&state.mem.t), None | Some(false)) { //only take this branch if the cursor is non-empty state.mem.program_i += 1; continue; @@ -756,24 +823,17 @@ pub(super) fn explain( OP_COLUMN => { //Get the row stored at p1, or NULL; get the column stored at p2, or NULL - if let Some(record) = - state.mem.p.get(&p1).map(|c| c.map_to_intmap(&state.mem.r)) - { - if let Some(col) = record.get(&p2) { - // insert into p3 the datatype of the col - state.mem.r.insert(p3, RegDataType::Single(col.clone())); - } else { - state - .mem - .r - .insert(p3, RegDataType::Single(ColumnType::default())); - } - } else { - state - .mem - .r - .insert(p3, RegDataType::Single(ColumnType::default())); - } + let value: ColumnType = state + .mem + .p + .get(&p1) + .and_then(|c| c.columns_ref(&state.mem.t, &state.mem.r)) + .and_then(|cc| cc.get(&p2)) + .cloned() + .unwrap_or_else(|| ColumnType::default()); + + // insert into p3 the datatype of the col + state.mem.r.insert(p3, RegDataType::Single(value)); } OP_SEQUENCE => { @@ -791,12 +851,16 @@ pub(super) fn explain( OP_ROW_DATA | OP_SORTER_DATA => { //Get entire row from cursor p1, store it into register p2 - if let Some(record) = state.mem.p.get(&p1) { - let rowdata = record.map_to_intmap(&state.mem.r); + if let Some(record) = state + .mem + .p + .get(&p1) + .map(|c| c.columns(&state.mem.t, &state.mem.r)) + { state .mem .r - .insert(p2, RegDataType::Single(ColumnType::Record(rowdata))); + .insert(p2, RegDataType::Single(ColumnType::Record(record))); } else { state .mem @@ -814,7 +878,7 @@ pub(super) fn explain( .mem .r .get(®) - .map(|d| d.clone().map_to_columntype()) + .map(|d| d.map_to_columntype()) .unwrap_or(ColumnType::default()), ); } @@ -828,8 +892,11 @@ pub(super) fn explain( if let Some(RegDataType::Single(ColumnType::Record(record))) = state.mem.r.get(&p2) { - if let Some(CursorDataType::Normal { cols, is_empty }) = - state.mem.p.get_mut(&p1) + if let Some(TableDataType { cols, is_empty }) = state + .mem + .p + .get(&p1) + .and_then(|cur| cur.table_mut(&mut state.mem.t)) { // Insert the record into wherever pointer p1 is *cols = record.clone(); @@ -841,7 +908,11 @@ pub(super) fn explain( OP_DELETE => { // delete a record from cursor p1 - if let Some(CursorDataType::Normal { is_empty, .. }) = state.mem.p.get_mut(&p1) + if let Some(TableDataType { is_empty, .. }) = state + .mem + .p + .get(&p1) + .and_then(|cur| cur.table_mut(&mut state.mem.t)) { if *is_empty == Some(false) { *is_empty = None; //the cursor might be empty now @@ -854,43 +925,52 @@ pub(super) fn explain( state.mem.p.insert(p1, CursorDataType::Pseudo(p2)); } + OP_OPEN_DUP => { + if let Some(cur) = state.mem.p.get(&p2) { + state.mem.p.insert(p1, cur.clone()); + } + } + OP_OPEN_READ | OP_OPEN_WRITE => { //Create a new pointer which is referenced by p1, take column metadata from db schema if found - if p3 == 0 || p3 == 1 { + let table_info = if p3 == 0 || p3 == 1 { if let Some(columns) = root_block_cols.get(&(p3, p2)) { - state - .mem - .p - .insert(p1, CursorDataType::from_intmap(columns, None)); + TableDataType { + cols: columns.clone(), + is_empty: None, + } } else { - state.mem.p.insert( - p1, - CursorDataType::Normal { - cols: IntMap::new(), - is_empty: None, - }, - ); - } - } else { - state.mem.p.insert( - p1, - CursorDataType::Normal { + TableDataType { cols: IntMap::new(), is_empty: None, - }, - ); - } + } + } + } else { + TableDataType { + cols: IntMap::new(), + is_empty: None, + } + }; + + state.mem.t.insert(state.mem.program_i as i64, table_info); + state + .mem + .p + .insert(p1, CursorDataType::Normal(state.mem.program_i as i64)); } OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX | OP_SORTER_OPEN => { //Create a new pointer which is referenced by p1 - state.mem.p.insert( - p1, - CursorDataType::from_dense_record( - &vec![ColumnType::null(); p2 as usize], - Some(true), - ), - ); + let table_info = TableDataType { + cols: IntMap::from_dense_record(&vec![ColumnType::null(); p2 as usize]), + is_empty: Some(true), + }; + + state.mem.t.insert(state.mem.program_i as i64, table_info); + state + .mem + .p + .insert(p1, CursorDataType::Normal(state.mem.program_i as i64)); } OP_VARIABLE => { @@ -912,7 +992,7 @@ pub(super) fn explain( } OP_FUNCTION => { - // r[p1] = func( _ ) + // r[p3] = func( _ ), registered function name is in p4 match from_utf8(p4).map_err(Error::protocol)? { "last_insert_rowid(0)" => { // last_insert_rowid() -> INTEGER @@ -961,8 +1041,11 @@ pub(super) fn explain( OP_NULL_ROW => { // all columns in cursor X are potentially nullable - if let Some(CursorDataType::Normal { ref mut cols, .. }) = - state.mem.p.get_mut(&p1) + if let Some(cols) = state + .mem + .p + .get_mut(&p1) + .and_then(|c| c.columns_mut(&mut state.mem.t, &mut state.mem.r)) { for col in cols.values_mut() { if let ColumnType::Single { @@ -994,6 +1077,15 @@ pub(super) fn explain( nullable: Some(false), }), ); + } else if p4.starts_with("percent_rank(") || p4.starts_with("cume_dist") { + // percent_rank(_) -> REAL + state.mem.r.insert( + p3, + RegDataType::Single(ColumnType::Single { + datatype: DataType::Float, + nullable: Some(false), + }), + ); } else if p4.starts_with("sum(") { if let Some(r_p2) = state.mem.r.get(&p2) { let datatype = match r_p2.map_to_datatype() { @@ -1008,6 +1100,17 @@ pub(super) fn explain( RegDataType::Single(ColumnType::Single { datatype, nullable }), ); } + } else if p4.starts_with("lead(") || p4.starts_with("lag(") { + if let Some(r_p2) = state.mem.r.get(&p2) { + let datatype = r_p2.map_to_datatype(); + state.mem.r.insert( + p3, + RegDataType::Single(ColumnType::Single { + datatype, + nullable: Some(true), + }), + ); + } } else if let Some(v) = state.mem.r.get(&p2).cloned() { // r[p3] = AGG ( r[p2] ) state.mem.r.insert(p3, v); @@ -1031,6 +1134,26 @@ pub(super) fn explain( nullable: Some(false), }), ); + } else if p4.starts_with("percent_rank(") || p4.starts_with("cume_dist") { + // percent_rank(_) -> REAL + state.mem.r.insert( + p3, + RegDataType::Single(ColumnType::Single { + datatype: DataType::Float, + nullable: Some(false), + }), + ); + } else if p4.starts_with("lead(") || p4.starts_with("lag(") { + if let Some(r_p2) = state.mem.r.get(&p2) { + let datatype = r_p2.map_to_datatype(); + state.mem.r.insert( + p3, + RegDataType::Single(ColumnType::Single { + datatype, + nullable: Some(true), + }), + ); + } } } @@ -1119,48 +1242,32 @@ pub(super) fn explain( OP_OR | OP_AND | OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { // r[p3] = r[p1] + r[p2] - match (state.mem.r.get(&p1).cloned(), state.mem.r.get(&p2).cloned()) { - (Some(a), Some(b)) => { - state.mem.r.insert( - p3, - RegDataType::Single(ColumnType::Single { - datatype: if matches!(a.map_to_datatype(), DataType::Null) { - b.map_to_datatype() - } else { - a.map_to_datatype() - }, - nullable: match (a.map_to_nullable(), b.map_to_nullable()) { - (Some(a_n), Some(b_n)) => Some(a_n | b_n), - (Some(a_n), None) => Some(a_n), - (None, Some(b_n)) => Some(b_n), - (None, None) => None, - }, - }), - ); - } + let value = match (state.mem.r.get(&p1), state.mem.r.get(&p2)) { + (Some(a), Some(b)) => RegDataType::Single(ColumnType::Single { + datatype: if matches!(a.map_to_datatype(), DataType::Null) { + b.map_to_datatype() + } else { + a.map_to_datatype() + }, + nullable: match (a.map_to_nullable(), b.map_to_nullable()) { + (Some(a_n), Some(b_n)) => Some(a_n | b_n), + (Some(a_n), None) => Some(a_n), + (None, Some(b_n)) => Some(b_n), + (None, None) => None, + }, + }), + (Some(v), None) => RegDataType::Single(ColumnType::Single { + datatype: v.map_to_datatype(), + nullable: None, + }), + (None, Some(v)) => RegDataType::Single(ColumnType::Single { + datatype: v.map_to_datatype(), + nullable: None, + }), + _ => RegDataType::default(), + }; - (Some(v), None) => { - state.mem.r.insert( - p3, - RegDataType::Single(ColumnType::Single { - datatype: v.map_to_datatype(), - nullable: None, - }), - ); - } - - (None, Some(v)) => { - state.mem.r.insert( - p3, - RegDataType::Single(ColumnType::Single { - datatype: v.map_to_datatype(), - nullable: None, - }), - ); - } - - _ => {} - } + state.mem.r.insert(p3, value); } OP_OFFSET_LIMIT => { diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 51fe58de..c10f16f4 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -865,3 +865,106 @@ async fn it_describes_with_recursive() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_describes_analytical_function() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn + .describe("select row_number() over () from accounts") + .await?; + dbg!(&d); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn.describe("select rank() over () from accounts").await?; + dbg!(&d); + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select dense_rank() over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select percent_rank() over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "REAL"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select cume_dist() over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "REAL"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select ntile(1) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select lag(id) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select lag(name) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "TEXT"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select lead(id) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select lead(name) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "TEXT"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select first_value(id) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select first_value(name) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "TEXT"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select last_value(id) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn + .describe("select first_value(name) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "TEXT"); + //assert_eq!(d.nullable(0), Some(false)); //this should be null, but it's hard to prove that it will be + + let d = conn + .describe("select nth_value(id,10) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn + .describe("select nth_value(name,10) over () from accounts") + .await?; + assert_eq!(d.column(0).type_info().name(), "TEXT"); + assert_eq!(d.nullable(0), Some(true)); + + Ok(()) +}