Support using both nullability and type overrides (#549)

* Make it possible to use both nullability and type overrides

* Fix override parsing lookahead logic

* Update column override tests

* Support nullability overrides with wildcard type overrides

* Fix tests

* Update query! overrides docs

* Remove last bits of macro_result!

* rustfmt
This commit is contained in:
Raphaël Thériault 2020-07-27 03:43:35 -04:00 committed by GitHub
parent 116fbc1942
commit ced09e0545
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 414 additions and 108 deletions

View File

@ -83,14 +83,10 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result<proc_macro2
migrations.sort_by_key(|m| m.version);
Ok(quote! {
macro_rules! macro_result {
() => {
sqlx::migrate::Migrator {
migrations: std::borrow::Cow::Borrowed(&[
#(#migrations),*
])
}
}
sqlx::migrate::Migrator {
migrations: std::borrow::Cow::Borrowed(&[
#(#migrations),*
])
}
})
}

View File

@ -223,7 +223,7 @@ where
let record_name: Type = syn::parse_str("Record").unwrap();
for rust_col in &columns {
if rust_col.type_.is_none() {
if rust_col.type_.is_wildcard() {
return Err(
"columns may not have wildcard overrides in `query!()` or `query_as!()"
.into(),

View File

@ -1,5 +1,5 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::Type;
use sqlx_core::column::Column;
@ -14,7 +14,32 @@ use syn::Token;
pub struct RustColumn {
pub(super) ident: Ident,
pub(super) type_: Option<TokenStream>,
pub(super) type_: ColumnType,
}
pub(super) enum ColumnType {
Exact(TokenStream),
Wildcard,
OptWildcard,
}
impl ColumnType {
pub(super) fn is_wildcard(&self) -> bool {
match self {
ColumnType::Exact(_) => false,
_ => true,
}
}
}
impl ToTokens for ColumnType {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.append_all(match &self {
ColumnType::Exact(type_) => type_.clone().into_iter(),
ColumnType::Wildcard => quote! { _ }.into_iter(),
ColumnType::OptWildcard => quote! { Option<_> }.into_iter(),
})
}
}
struct DisplayColumn<'a> {
@ -25,15 +50,25 @@ struct DisplayColumn<'a> {
struct ColumnDecl {
ident: Ident,
// TIL Rust still has OOP keywords like `abstract`, `final`, `override` and `virtual` reserved
r#override: Option<ColumnOverride>,
r#override: ColumnOverride,
}
enum ColumnOverride {
struct ColumnOverride {
nullability: ColumnNullabilityOverride,
type_: ColumnTypeOverride,
}
#[derive(PartialEq)]
enum ColumnNullabilityOverride {
NonNull,
Nullable,
Wildcard,
None,
}
enum ColumnTypeOverride {
Exact(Type),
Wildcard,
None,
}
impl Display for DisplayColumn<'_> {
@ -52,22 +87,30 @@ pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Resul
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let type_ = match decl.r#override {
Some(ColumnOverride::Exact(ty)) => Some(ty.to_token_stream()),
Some(ColumnOverride::Wildcard) => None,
// these three could be combined but I prefer the clarity here
Some(ColumnOverride::NonNull) => Some(get_column_type::<DB>(i, column)),
Some(ColumnOverride::Nullable) => {
let type_ = get_column_type::<DB>(i, column);
Some(quote! { Option<#type_> })
}
None => {
let type_ = get_column_type::<DB>(i, column);
let ColumnOverride { nullability, type_ } = decl.r#override;
if !describe.nullable(i).unwrap_or(true) {
Some(type_)
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => {
ColumnType::Exact(type_.to_token_stream())
}
(ColumnTypeOverride::Exact(type_), true) => {
ColumnType::Exact(quote! { Option<#type_> })
}
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
Some(quote! { Option<#type_> })
ColumnType::Exact(quote! { Option<#type_> })
}
}
};
@ -97,14 +140,17 @@ pub fn quote_query_as<DB: DatabaseExt>(
)| {
match (input.checked, type_) {
// we guarantee the type is valid so we can skip the runtime check
(true, Some(type_)) => quote! {
(true, ColumnType::Exact(type_)) => quote! {
// binding to a `let` avoids confusing errors about
// "try expression alternatives have incompatible types"
// it doesn't seem to hurt inference in the other branches
let #ident = row.try_get_unchecked::<#type_, _>(#i)?;
},
// type was overridden to be a wildcard so we fallback to the runtime check
(true, None) => quote! ( let #ident = row.try_get(#i)?; ),
(true, ColumnType::Wildcard) => quote! ( let #ident = row.try_get(#i)?; ),
(true, ColumnType::OptWildcard) => {
quote! ( let #ident = row.try_get::<Option<_>, _>(#i)?; )
}
// macro is the `_unchecked!()` variant so this will die in decoding if it's wrong
(false, _) => quote!( let #ident = row.try_get_unchecked(#i)?; ),
}
@ -176,9 +222,12 @@ impl ColumnDecl {
Ok(ColumnDecl {
ident,
r#override: if !remainder.is_empty() {
Some(syn::parse_str(remainder)?)
syn::parse_str(remainder)?
} else {
None
ColumnOverride {
nullability: ColumnNullabilityOverride::None,
type_: ColumnTypeOverride::None,
}
},
})
}
@ -188,27 +237,33 @@ impl Parse for ColumnOverride {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(Token![:]) {
let nullability = if lookahead.peek(Token![!]) {
input.parse::<Token![!]>()?;
ColumnNullabilityOverride::NonNull
} else if lookahead.peek(Token![?]) {
input.parse::<Token![?]>()?;
ColumnNullabilityOverride::Nullable
} else {
ColumnNullabilityOverride::None
};
let type_ = if input.lookahead1().peek(Token![:]) {
input.parse::<Token![:]>()?;
let ty = Type::parse(input)?;
if let Type::Infer(_) = ty {
Ok(ColumnOverride::Wildcard)
ColumnTypeOverride::Wildcard
} else {
Ok(ColumnOverride::Exact(ty))
ColumnTypeOverride::Exact(ty)
}
} else if lookahead.peek(Token![!]) {
input.parse::<Token![!]>()?;
Ok(ColumnOverride::NonNull)
} else if lookahead.peek(Token![?]) {
input.parse::<Token![?]>()?;
Ok(ColumnOverride::Nullable)
} else {
Err(lookahead.error())
}
ColumnTypeOverride::None
};
Ok(Self { nullability, type_ })
}
}

View File

@ -215,7 +215,8 @@
/// Selecting a column `foo as "foo: T"` (Postgres / SQLite) or `` foo as `foo: T` `` (MySQL)
/// overrides the inferred type which is useful when selecting user-defined custom types
/// (dynamic type checking is still done so if the types are incompatible this will be an error
/// at runtime instead of compile-time):
/// at runtime instead of compile-time). Note that this syntax alone doesn't override inferred nullability,
/// but it is compatible with the forced not-null and forced nullable annotations:
///
/// ```rust,ignore
/// # async fn main() {
@ -227,15 +228,27 @@
/// let my_int = MyInt4(1);
///
/// // Postgres/SQLite
/// sqlx::query!(r#"select 1 as "id: MyInt4""#) // MySQL: use "select 1 as `id: MyInt4`" instead
/// sqlx::query!(r#"select 1 as "id!: MyInt4""#) // MySQL: use "select 1 as `id: MyInt4`" instead
/// .fetch_one(&mut conn)
/// .await?;
///
/// // For Postgres this would have been inferred to be `Option<i32>`, MySQL/SQLite `i32`
/// // Note that while using `id: MyInt4` (without the `!`) would work the same for MySQL/SQLite,
/// // Postgres would expect `Some(MyInt4(1))` and the code wouldn't compile
/// assert_eq!(record.id, MyInt4(1));
/// # }
/// ```
///
/// ##### Overrides cheatsheet
///
/// | Syntax | Nullability | Type |
/// | --------- | --------------- | ---------- |
/// | `foo!` | Forced not-null | Inferred |
/// | `foo?` | Forced nullable | Inferred |
/// | `foo: T` | Inferred | Overridden |
/// | `foo!: T` | Forced not-null | Overridden |
/// | `foo?: T` | Forced nullable | Overridden |
///
/// ## Offline Mode (requires the `offline` feature)
/// The macros can be configured to not require a live database connection for compilation,
/// but it requires a couple extra steps:
@ -601,18 +614,10 @@ macro_rules! query_file_as_unchecked (
#[macro_export]
macro_rules! migrate {
($dir:literal) => {{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::migrate!($dir);
}
macro_result!()
$crate::sqlx_macros::migrate!($dir)
}};
() => {{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::migrate!("migrations");
}
macro_result!()
$crate::sqlx_macros::migrate!("migrations")
}};
}

View File

@ -1,8 +1,4 @@
use futures::{FutureExt, TryFutureExt};
use sqlx::any::AnyPoolOptions;
use sqlx::prelude::*;
use sqlx_core::any::AnyPool;
use sqlx_test::new;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
@ -16,7 +12,7 @@ async fn pool_should_invoke_after_connect() -> anyhow::Result<()> {
let pool = AnyPoolOptions::new()
.after_connect({
let counter = counter.clone();
move |conn| {
move |_conn| {
let counter = counter.clone();
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);

View File

@ -1,4 +1,4 @@
use sqlx::MySql;
use sqlx::{Connection, MySql, MySqlConnection, Transaction};
use sqlx_test::new;
#[sqlx_macros::test]
@ -120,9 +120,137 @@ async fn test_column_override_nullable() -> anyhow::Result<()> {
Ok(())
}
async fn with_test_row<'a>(
conn: &'a mut MySqlConnection,
) -> anyhow::Result<Transaction<'a, MySql>> {
let mut transaction = conn.begin().await?;
sqlx::query!("INSERT INTO tweet(id, text, owner_id) VALUES (1, '#sqlx is pretty cool!', 1)")
.execute(&mut transaction)
.await?;
Ok(transaction)
}
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct MyInt4(i32);
struct MyInt(i64);
struct Record {
id: MyInt,
}
struct OptionalRecord {
id: Option<MyInt>,
}
#[sqlx_macros::test]
async fn test_column_override_wildcard() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(Record, "select id as `id: _` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
// this syntax is also useful for expressions
let record = sqlx::query_as!(Record, "select * from (select 1 as `id: _`) records")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
let record = sqlx::query_as!(OptionalRecord, "select owner_id as `id: _` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_not_null() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(Record, "select owner_id as `id!: _` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_nullable() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(OptionalRecord, "select id as `id?: _` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!("select id as `id: MyInt` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
// we can also support this syntax for expressions
let record = sqlx::query!("select * from (select 1 as `id: MyInt`) records")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
let record = sqlx::query!("select owner_id as `id: MyInt` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_not_null() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!("select owner_id as `id!: MyInt` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_nullable() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!("select id as `id?: MyInt` from tweet")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(rename_all = "lowercase")]
@ -140,36 +268,6 @@ enum MyCEnum {
Blue,
}
#[sqlx_macros::test]
async fn test_column_override_wildcard() -> anyhow::Result<()> {
struct Record {
id: MyInt4,
}
let mut conn = new::<MySql>().await?;
let record = sqlx::query_as!(Record, "select * from (select 1 as `id: _`) records")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt4(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
let record = sqlx::query!("select * from (select 1 as `id: MyInt4`) records")
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt4(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_enum() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

View File

@ -1,4 +1,4 @@
use sqlx::{Connection, Postgres};
use sqlx::{Connection, PgConnection, Postgres, Transaction};
use sqlx_test::new;
use futures::TryStreamExt;
@ -351,23 +351,76 @@ async fn test_column_override_nullable() -> anyhow::Result<()> {
Ok(())
}
async fn with_test_row<'a>(
conn: &'a mut PgConnection,
) -> anyhow::Result<Transaction<'a, Postgres>> {
let mut transaction = conn.begin().await?;
sqlx::query!("INSERT INTO tweet(id, text, owner_id) VALUES (1, '#sqlx is pretty cool!', 1)")
.execute(&mut transaction)
.await?;
Ok(transaction)
}
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct MyInt(i64);
#[derive(PartialEq, Eq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct MyInt4(i32);
struct Record {
id: MyInt,
}
struct OptionalRecord {
id: Option<MyInt>,
}
#[sqlx_macros::test]
async fn test_column_override_wildcard() -> anyhow::Result<()> {
struct Record {
id: MyInt4,
}
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(Record, r#"select 1 as "id: _""#)
let record = sqlx::query_as!(Record, r#"select id as "id: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt4(1));
assert_eq!(record.id, MyInt(1));
let record = sqlx::query_as!(OptionalRecord, r#"select owner_id as "id: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_not_null() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(Record, r#"select owner_id as "id!: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_nullable() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query_as!(OptionalRecord, r#"select id as "id?: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
@ -375,12 +428,47 @@ async fn test_column_override_wildcard() -> anyhow::Result<()> {
#[sqlx_macros::test]
async fn test_column_override_exact() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!(r#"select 1 as "id: MyInt4""#)
let record = sqlx::query!(r#"select id as "id: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt4(1));
assert_eq!(record.id, MyInt(1));
let record = sqlx::query!(r#"select owner_id as "id: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_not_null() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!(r#"select owner_id as "id!: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_nullable() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
let mut conn = with_test_row(&mut conn).await?;
let record = sqlx::query!(r#"select id as "id?: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}

View File

@ -156,12 +156,16 @@ async fn test_column_override_nullable() -> anyhow::Result<()> {
#[sqlx(transparent)]
struct MyInt(i64);
struct Record {
id: MyInt,
}
struct OptionalRecord {
id: Option<MyInt>,
}
#[sqlx_macros::test]
async fn test_column_override_wildcard() -> anyhow::Result<()> {
struct Record {
id: MyInt,
}
let mut conn = new::<Sqlite>().await?;
let record = sqlx::query_as!(Record, r#"select id as "id: _" from tweet"#)
@ -177,6 +181,38 @@ async fn test_column_override_wildcard() -> anyhow::Result<()> {
assert_eq!(record.id, MyInt(1));
let record = sqlx::query_as!(OptionalRecord, r#"select owner_id as "id: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_not_null() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let record = sqlx::query_as!(Record, r#"select owner_id as "id!: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_wildcard_nullable() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let record = sqlx::query_as!(OptionalRecord, r#"select id as "id?: _" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
@ -197,6 +233,38 @@ async fn test_column_override_exact() -> anyhow::Result<()> {
assert_eq!(record.id, MyInt(1));
let record = sqlx::query!(r#"select owner_id as "id: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_not_null() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let record = sqlx::query!(r#"select owner_id as "id!: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, MyInt(1));
Ok(())
}
#[sqlx_macros::test]
async fn test_column_override_exact_nullable() -> anyhow::Result<()> {
let mut conn = new::<Sqlite>().await?;
let record = sqlx::query!(r#"select id as "id?: MyInt" from tweet"#)
.fetch_one(&mut conn)
.await?;
assert_eq!(record.id, Some(MyInt(1)));
Ok(())
}