diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index 97e74efa..8bc9f9ce 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -78,10 +78,11 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( ty }; - nullable.push(stmt.column_nullable(col)?.or_else(|| { - // if we do not *know* if this is nullable, check the EXPLAIN fallback - fallback_nullable.get(col).copied().and_then(identity) - })); + // check explain + let col_nullable = stmt.column_nullable(col)?; + let exp_nullable = fallback_nullable.get(col).copied().and_then(identity); + + nullable.push(exp_nullable.or(col_nullable)); columns.push(SqliteColumn { name: name.into(), diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 23dae1e3..376bfa9e 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -22,6 +22,7 @@ const OP_FUNCTION: &str = "Function"; const OP_MOVE: &str = "Move"; const OP_COPY: &str = "Copy"; const OP_SCOPY: &str = "SCopy"; +const OP_NULL_ROW: &str = "NullRow"; const OP_INT_COPY: &str = "IntCopy"; const OP_CAST: &str = "Cast"; const OP_STRING8: &str = "String8"; @@ -77,6 +78,8 @@ pub(super) async fn explain( query: &str, ) -> Result<(Vec, Vec>), Error> { let mut r = HashMap::::with_capacity(6); + let mut r_cursor = HashMap::>::with_capacity(6); + let mut n = HashMap::::with_capacity(6); let program = @@ -87,6 +90,11 @@ pub(super) async fn explain( let mut program_i = 0; let program_size = program.len(); + let mut output = Vec::new(); + let mut nullable = Vec::new(); + + let mut result = None; + while program_i < program_size { let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i]; @@ -104,9 +112,10 @@ pub(super) async fn explain( } OP_COLUMN => { + r_cursor.entry(p1).or_default().push(p3); + // r[p3] = r.insert(p3, DataType::Null); - n.insert(p3, true); } OP_VARIABLE => { @@ -117,10 +126,21 @@ pub(super) async fn explain( OP_FUNCTION => { // r[p1] = func( _ ) - if from_utf8(p4).map_err(Error::protocol)? == "last_insert_rowid(0)" { - // last_insert_rowid() -> INTEGER - r.insert(p3, DataType::Int64); - n.insert(p3, false); + match from_utf8(p4).map_err(Error::protocol)? { + "last_insert_rowid(0)" => { + // last_insert_rowid() -> INTEGER + r.insert(p3, DataType::Int64); + n.insert(p3, n.get(&p3).copied().unwrap_or(false)); + } + + _ => {} + } + } + + OP_NULL_ROW => { + // all values of cursor X are potentially nullable + for column in &r_cursor[&p1] { + n.insert(*column, true); } } @@ -130,7 +150,7 @@ pub(super) async fn explain( if p4.starts_with("count(") { // count(_) -> INTEGER r.insert(p3, DataType::Int64); - n.insert(p3, false); + n.insert(p3, n.get(&p3).copied().unwrap_or(false)); } else if let Some(v) = r.get(&p2).copied() { // r[p3] = AGG ( r[p2] ) r.insert(p3, v); @@ -150,15 +170,19 @@ pub(super) async fn explain( // r[p2] = r[p1] if let Some(v) = r.get(&p1).copied() { r.insert(p2, v); - let val = n.get(&p1).copied().unwrap_or(true); - n.insert(p2, val); + + if let Some(null) = n.get(&p1).copied() { + n.insert(p2, null); + } } } OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => { // r[p2] = r.insert(p2, opcode_to_type(&opcode)); - n.insert(p2, false); + n.insert(p2, n.get(&p2).copied().unwrap_or(false)); + + println!("[x] set column {} as INTEGER", p2); } OP_NOT => { @@ -194,34 +218,21 @@ pub(super) async fn explain( n.insert(p3, a || b); } - (None, Some(b)) => { - n.insert(p3, b); - } - - (Some(a), None) => { - n.insert(p3, a); - } - _ => {} } } OP_RESULT_ROW => { - // output = r[p1 .. p1 + p2] - let mut output = Vec::with_capacity(p2 as usize); - let mut nullable = Vec::with_capacity(p2 as usize); - - for i in p1..p1 + p2 { - output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); - - nullable.push(if n.remove(&i).unwrap_or(true) { - None - } else { - Some(false) - }); + // the second time we hit ResultRow we short-circuit and get out + if result.is_some() { + break; } - return Ok((output, nullable)); + // output = r[p1 .. p1 + p2] + output.reserve(p2 as usize); + nullable.reserve(p2 as usize); + + result = Some(p1..p1 + p2); } _ => { @@ -233,6 +244,12 @@ pub(super) async fn explain( program_i += 1; } - // no rows - Ok((vec![], vec![])) + if let Some(result) = result { + for i in result { + output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); + nullable.push(n.remove(&i)); + } + } + + Ok((output, nullable)) } diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 1583a9f5..b57c151e 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -186,3 +186,36 @@ async fn it_describes_bad_statement() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_describes_left_join() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let d = conn.describe("select accounts.id from accounts").await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + let d = conn.describe("select tweet.id from accounts left join tweet on owner_id = accounts.id").await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + let d = conn.describe("select tweet.id, accounts.id from accounts left join tweet on owner_id = accounts.id").await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(true)); + + assert_eq!(d.column(1).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(1), Some(false)); + + let d = conn.describe("select tweet.id, accounts.id from accounts inner join tweet on owner_id = accounts.id").await?; + + assert_eq!(d.column(0).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(0), Some(false)); + + assert_eq!(d.column(1).type_info().name(), "INTEGER"); + assert_eq!(d.nullable(1), Some(false)); + + Ok(()) +}