feat(sqlite): track nullable through left joins

This commit is contained in:
Ryan Leckey 2020-07-27 03:11:31 -07:00
parent 0c0dd6936a
commit b0c430ed18
No known key found for this signature in database
GPG Key ID: F8AA68C235AB08C9
3 changed files with 87 additions and 36 deletions

View File

@ -78,10 +78,11 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>(
ty
};
nullable.push(stmt.column_nullable(col)?.or_else(|| {
// if we do not *know* if this is nullable, check the EXPLAIN fallback
fallback_nullable.get(col).copied().and_then(identity)
}));
// check explain
let col_nullable = stmt.column_nullable(col)?;
let exp_nullable = fallback_nullable.get(col).copied().and_then(identity);
nullable.push(exp_nullable.or(col_nullable));
columns.push(SqliteColumn {
name: name.into(),

View File

@ -22,6 +22,7 @@ const OP_FUNCTION: &str = "Function";
const OP_MOVE: &str = "Move";
const OP_COPY: &str = "Copy";
const OP_SCOPY: &str = "SCopy";
const OP_NULL_ROW: &str = "NullRow";
const OP_INT_COPY: &str = "IntCopy";
const OP_CAST: &str = "Cast";
const OP_STRING8: &str = "String8";
@ -77,6 +78,8 @@ pub(super) async fn explain(
query: &str,
) -> Result<(Vec<SqliteTypeInfo>, Vec<Option<bool>>), Error> {
let mut r = HashMap::<i64, DataType>::with_capacity(6);
let mut r_cursor = HashMap::<i64, Vec<i64>>::with_capacity(6);
let mut n = HashMap::<i64, bool>::with_capacity(6);
let program =
@ -87,6 +90,11 @@ pub(super) async fn explain(
let mut program_i = 0;
let program_size = program.len();
let mut output = Vec::new();
let mut nullable = Vec::new();
let mut result = None;
while program_i < program_size {
let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i];
@ -104,9 +112,10 @@ pub(super) async fn explain(
}
OP_COLUMN => {
r_cursor.entry(p1).or_default().push(p3);
// r[p3] = <value of column>
r.insert(p3, DataType::Null);
n.insert(p3, true);
}
OP_VARIABLE => {
@ -117,10 +126,21 @@ pub(super) async fn explain(
OP_FUNCTION => {
// r[p1] = func( _ )
if from_utf8(p4).map_err(Error::protocol)? == "last_insert_rowid(0)" {
// last_insert_rowid() -> INTEGER
r.insert(p3, DataType::Int64);
n.insert(p3, false);
match from_utf8(p4).map_err(Error::protocol)? {
"last_insert_rowid(0)" => {
// last_insert_rowid() -> INTEGER
r.insert(p3, DataType::Int64);
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
}
_ => {}
}
}
OP_NULL_ROW => {
// all values of cursor X are potentially nullable
for column in &r_cursor[&p1] {
n.insert(*column, true);
}
}
@ -130,7 +150,7 @@ pub(super) async fn explain(
if p4.starts_with("count(") {
// count(_) -> INTEGER
r.insert(p3, DataType::Int64);
n.insert(p3, false);
n.insert(p3, n.get(&p3).copied().unwrap_or(false));
} else if let Some(v) = r.get(&p2).copied() {
// r[p3] = AGG ( r[p2] )
r.insert(p3, v);
@ -150,15 +170,19 @@ pub(super) async fn explain(
// r[p2] = r[p1]
if let Some(v) = r.get(&p1).copied() {
r.insert(p2, v);
let val = n.get(&p1).copied().unwrap_or(true);
n.insert(p2, val);
if let Some(null) = n.get(&p1).copied() {
n.insert(p2, null);
}
}
}
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
// r[p2] = <value of constant>
r.insert(p2, opcode_to_type(&opcode));
n.insert(p2, false);
n.insert(p2, n.get(&p2).copied().unwrap_or(false));
println!("[x] <?> set column {} as INTEGER", p2);
}
OP_NOT => {
@ -194,34 +218,21 @@ pub(super) async fn explain(
n.insert(p3, a || b);
}
(None, Some(b)) => {
n.insert(p3, b);
}
(Some(a), None) => {
n.insert(p3, a);
}
_ => {}
}
}
OP_RESULT_ROW => {
// output = r[p1 .. p1 + p2]
let mut output = Vec::with_capacity(p2 as usize);
let mut nullable = Vec::with_capacity(p2 as usize);
for i in p1..p1 + p2 {
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
nullable.push(if n.remove(&i).unwrap_or(true) {
None
} else {
Some(false)
});
// the second time we hit ResultRow we short-circuit and get out
if result.is_some() {
break;
}
return Ok((output, nullable));
// output = r[p1 .. p1 + p2]
output.reserve(p2 as usize);
nullable.reserve(p2 as usize);
result = Some(p1..p1 + p2);
}
_ => {
@ -233,6 +244,12 @@ pub(super) async fn explain(
program_i += 1;
}
// no rows
Ok((vec![], vec![]))
if let Some(result) = result {
for i in result {
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
nullable.push(n.remove(&i));
}
}
Ok((output, nullable))
}

View File

@ -186,3 +186,36 @@ async fn it_describes_bad_statement() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_describes_left_join() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let d = conn.describe("select accounts.id from accounts").await?;
assert_eq!(d.column(0).type_info().name(), "INTEGER");
assert_eq!(d.nullable(0), Some(false));
let d = conn.describe("select tweet.id from accounts left join tweet on owner_id = accounts.id").await?;
assert_eq!(d.column(0).type_info().name(), "INTEGER");
assert_eq!(d.nullable(0), Some(true));
let d = conn.describe("select tweet.id, accounts.id from accounts left join tweet on owner_id = accounts.id").await?;
assert_eq!(d.column(0).type_info().name(), "INTEGER");
assert_eq!(d.nullable(0), Some(true));
assert_eq!(d.column(1).type_info().name(), "INTEGER");
assert_eq!(d.nullable(1), Some(false));
let d = conn.describe("select tweet.id, accounts.id from accounts inner join tweet on owner_id = accounts.id").await?;
assert_eq!(d.column(0).type_info().name(), "INTEGER");
assert_eq!(d.nullable(0), Some(false));
assert_eq!(d.column(1).type_info().name(), "INTEGER");
assert_eq!(d.nullable(1), Some(false));
Ok(())
}