mirror of
https://github.com/launchbadge/sqlx.git
synced 2025-12-29 21:00:54 +00:00
feat(sqlite): track nullable through left joins
This commit is contained in:
parent
0c0dd6936a
commit
b0c430ed18
@ -78,10 +78,11 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
|
||||
ty
|
||||
};
|
||||
|
||||
nullable.push(stmt.column_nullable(col)?.or_else(|| {
|
||||
// if we do not *know* if this is nullable, check the EXPLAIN fallback
|
||||
fallback_nullable.get(col).copied().and_then(identity)
|
||||
}));
|
||||
// check explain
|
||||
let col_nullable = stmt.column_nullable(col)?;
|
||||
let exp_nullable = fallback_nullable.get(col).copied().and_then(identity);
|
||||
|
||||
nullable.push(exp_nullable.or(col_nullable));
|
||||
|
||||
columns.push(SqliteColumn {
|
||||
name: name.into(),
|
||||
|
||||
@ -22,6 +22,7 @@ const OP_FUNCTION: &str = "Function";
|
||||
const OP_MOVE: &str = "Move";
|
||||
const OP_COPY: &str = "Copy";
|
||||
const OP_SCOPY: &str = "SCopy";
|
||||
const OP_NULL_ROW: &str = "NullRow";
|
||||
const OP_INT_COPY: &str = "IntCopy";
|
||||
const OP_CAST: &str = "Cast";
|
||||
const OP_STRING8: &str = "String8";
|
||||
@ -77,6 +78,8 @@ pub(super) async fn explain(
|
||||
query: &str,
|
||||
) -> Result<(Vec<SqliteTypeInfo>, Vec<Option<bool>>), Error> {
|
||||
let mut r = HashMap::<i64, DataType>::with_capacity(6);
|
||||
let mut r_cursor = HashMap::<i64, Vec<i64>>::with_capacity(6);
|
||||
|
||||
let mut n = HashMap::<i64, bool>::with_capacity(6);
|
||||
|
||||
let program =
|
||||
@ -87,6 +90,11 @@ pub(super) async fn explain(
|
||||
let mut program_i = 0;
|
||||
let program_size = program.len();
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut nullable = Vec::new();
|
||||
|
||||
let mut result = None;
|
||||
|
||||
while program_i < program_size {
|
||||
let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i];
|
||||
|
||||
@ -104,9 +112,10 @@ pub(super) async fn explain(
|
||||
}
|
||||
|
||||
OP_COLUMN => {
|
||||
r_cursor.entry(p1).or_default().push(p3);
|
||||
|
||||
// r[p3] = <value of column>
|
||||
r.insert(p3, DataType::Null);
|
||||
n.insert(p3, true);
|
||||
}
|
||||
|
||||
OP_VARIABLE => {
|
||||
@ -117,10 +126,21 @@ pub(super) async fn explain(
|
||||
|
||||
OP_FUNCTION => {
|
||||
// r[p1] = func( _ )
|
||||
if from_utf8(p4).map_err(Error::protocol)? == "last_insert_rowid(0)" {
|
||||
// last_insert_rowid() -> INTEGER
|
||||
r.insert(p3, DataType::Int64);
|
||||
n.insert(p3, false);
|
||||
match from_utf8(p4).map_err(Error::protocol)? {
|
||||
"last_insert_rowid(0)" => {
|
||||
// last_insert_rowid() -> INTEGER
|
||||
r.insert(p3, DataType::Int64);
|
||||
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
OP_NULL_ROW => {
|
||||
// all values of cursor X are potentially nullable
|
||||
for column in &r_cursor[&p1] {
|
||||
n.insert(*column, true);
|
||||
}
|
||||
}
|
||||
|
||||
@ -130,7 +150,7 @@ pub(super) async fn explain(
|
||||
if p4.starts_with("count(") {
|
||||
// count(_) -> INTEGER
|
||||
r.insert(p3, DataType::Int64);
|
||||
n.insert(p3, false);
|
||||
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
|
||||
} else if let Some(v) = r.get(&p2).copied() {
|
||||
// r[p3] = AGG ( r[p2] )
|
||||
r.insert(p3, v);
|
||||
@ -150,15 +170,19 @@ pub(super) async fn explain(
|
||||
// r[p2] = r[p1]
|
||||
if let Some(v) = r.get(&p1).copied() {
|
||||
r.insert(p2, v);
|
||||
let val = n.get(&p1).copied().unwrap_or(true);
|
||||
n.insert(p2, val);
|
||||
|
||||
if let Some(null) = n.get(&p1).copied() {
|
||||
n.insert(p2, null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
|
||||
// r[p2] = <value of constant>
|
||||
r.insert(p2, opcode_to_type(&opcode));
|
||||
n.insert(p2, false);
|
||||
n.insert(p2, n.get(&p2).copied().unwrap_or(false));
|
||||
|
||||
println!("[x] <?> set column {} as INTEGER", p2);
|
||||
}
|
||||
|
||||
OP_NOT => {
|
||||
@ -194,34 +218,21 @@ pub(super) async fn explain(
|
||||
n.insert(p3, a || b);
|
||||
}
|
||||
|
||||
(None, Some(b)) => {
|
||||
n.insert(p3, b);
|
||||
}
|
||||
|
||||
(Some(a), None) => {
|
||||
n.insert(p3, a);
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
OP_RESULT_ROW => {
|
||||
// output = r[p1 .. p1 + p2]
|
||||
let mut output = Vec::with_capacity(p2 as usize);
|
||||
let mut nullable = Vec::with_capacity(p2 as usize);
|
||||
|
||||
for i in p1..p1 + p2 {
|
||||
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
|
||||
|
||||
nullable.push(if n.remove(&i).unwrap_or(true) {
|
||||
None
|
||||
} else {
|
||||
Some(false)
|
||||
});
|
||||
// the second time we hit ResultRow we short-circuit and get out
|
||||
if result.is_some() {
|
||||
break;
|
||||
}
|
||||
|
||||
return Ok((output, nullable));
|
||||
// output = r[p1 .. p1 + p2]
|
||||
output.reserve(p2 as usize);
|
||||
nullable.reserve(p2 as usize);
|
||||
|
||||
result = Some(p1..p1 + p2);
|
||||
}
|
||||
|
||||
_ => {
|
||||
@ -233,6 +244,12 @@ pub(super) async fn explain(
|
||||
program_i += 1;
|
||||
}
|
||||
|
||||
// no rows
|
||||
Ok((vec![], vec![]))
|
||||
if let Some(result) = result {
|
||||
for i in result {
|
||||
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
|
||||
nullable.push(n.remove(&i));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((output, nullable))
|
||||
}
|
||||
|
||||
@ -186,3 +186,36 @@ async fn it_describes_bad_statement() -> anyhow::Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_describes_left_join() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Sqlite>().await?;
|
||||
|
||||
let d = conn.describe("select accounts.id from accounts").await?;
|
||||
|
||||
assert_eq!(d.column(0).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(0), Some(false));
|
||||
|
||||
let d = conn.describe("select tweet.id from accounts left join tweet on owner_id = accounts.id").await?;
|
||||
|
||||
assert_eq!(d.column(0).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(0), Some(true));
|
||||
|
||||
let d = conn.describe("select tweet.id, accounts.id from accounts left join tweet on owner_id = accounts.id").await?;
|
||||
|
||||
assert_eq!(d.column(0).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(0), Some(true));
|
||||
|
||||
assert_eq!(d.column(1).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(1), Some(false));
|
||||
|
||||
let d = conn.describe("select tweet.id, accounts.id from accounts inner join tweet on owner_id = accounts.id").await?;
|
||||
|
||||
assert_eq!(d.column(0).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(0), Some(false));
|
||||
|
||||
assert_eq!(d.column(1).type_info().name(), "INTEGER");
|
||||
assert_eq!(d.nullable(1), Some(false));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user