diff --git a/sqlx-postgres/src/type_info.rs b/sqlx-postgres/src/type_info.rs index 4c902009..b01a1bfa 100644 --- a/sqlx-postgres/src/type_info.rs +++ b/sqlx-postgres/src/type_info.rs @@ -1154,7 +1154,103 @@ impl PartialEq for PgType { true } else { // Otherwise, perform a match on the name - self.name().eq_ignore_ascii_case(other.name()) + name_eq(self.name(), other.name()) } } } + +/// Check type names for equality, respecting Postgres' case sensitivity rules for identifiers. +/// +/// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +fn name_eq(name1: &str, name2: &str) -> bool { + // Cop-out of processing Unicode escapes by just using string equality. + if name1.starts_with("U&") { + // If `name2` doesn't start with `U&` this will automatically be `false`. + return name1 == name2; + } + + let mut chars1 = identifier_chars(name1); + let mut chars2 = identifier_chars(name2); + + while let (Some(a), Some(b)) = (chars1.next(), chars2.next()) { + if !a.eq(&b) { + return false; + } + } + + chars1.next().is_none() && chars2.next().is_none() +} + +struct IdentifierChar { + ch: char, + case_sensitive: bool, +} + +impl IdentifierChar { + fn eq(&self, other: &Self) -> bool { + if self.case_sensitive || other.case_sensitive { + self.ch == other.ch + } else { + self.ch.eq_ignore_ascii_case(&other.ch) + } + } +} + +/// Return an iterator over all significant characters of an identifier. +/// +/// Ignores non-escaped quotation marks. +fn identifier_chars(ident: &str) -> impl Iterator + '_ { + let mut case_sensitive = false; + let mut last_char_quote = false; + + ident.chars().filter_map(move |ch| { + if ch == '"' { + if last_char_quote { + last_char_quote = false; + } else { + last_char_quote = true; + return None; + } + } else if last_char_quote { + last_char_quote = false; + case_sensitive = !case_sensitive; + } + + Some(IdentifierChar { ch, case_sensitive }) + }) +} + +#[test] +fn test_name_eq() { + let test_values = [ + ("foo", "foo", true), + ("foo", "Foo", true), + ("foo", "FOO", true), + ("foo", r#""foo""#, true), + ("foo", r#""Foo""#, false), + ("foo", "foo.foo", false), + ("foo.foo", "foo.foo", true), + ("foo.foo", "foo.Foo", true), + ("foo.foo", "foo.FOO", true), + ("foo.foo", "Foo.foo", true), + ("foo.foo", "Foo.Foo", true), + ("foo.foo", "FOO.FOO", true), + ("foo.foo", "foo", false), + ("foo.foo", r#"foo."foo""#, true), + ("foo.foo", r#"foo."Foo""#, false), + ("foo.foo", r#"foo."FOO""#, false), + ]; + + for (left, right, eq) in test_values { + assert_eq!( + name_eq(left, right), + eq, + "failed check for name_eq({left:?}, {right:?})" + ); + assert_eq!( + name_eq(right, left), + eq, + "failed check for name_eq({right:?}, {left:?})" + ); + } +}