fix(postgres): case-aware type name equality

This commit is contained in:
Austin Bonander 2024-05-31 14:35:53 -07:00
parent 32143363bc
commit 2618439663

View File

@ -1154,7 +1154,103 @@ impl PartialEq<PgType> for PgType {
true
} else {
// Otherwise, perform a match on the name
self.name().eq_ignore_ascii_case(other.name())
name_eq(self.name(), other.name())
}
}
}
/// Check type names for equality, respecting Postgres' case sensitivity rules for identifiers.
///
/// https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
fn name_eq(name1: &str, name2: &str) -> bool {
// Cop-out of processing Unicode escapes by just using string equality.
if name1.starts_with("U&") {
// If `name2` doesn't start with `U&` this will automatically be `false`.
return name1 == name2;
}
let mut chars1 = identifier_chars(name1);
let mut chars2 = identifier_chars(name2);
while let (Some(a), Some(b)) = (chars1.next(), chars2.next()) {
if !a.eq(&b) {
return false;
}
}
chars1.next().is_none() && chars2.next().is_none()
}
struct IdentifierChar {
ch: char,
case_sensitive: bool,
}
impl IdentifierChar {
fn eq(&self, other: &Self) -> bool {
if self.case_sensitive || other.case_sensitive {
self.ch == other.ch
} else {
self.ch.eq_ignore_ascii_case(&other.ch)
}
}
}
/// Return an iterator over all significant characters of an identifier.
///
/// Ignores non-escaped quotation marks.
fn identifier_chars(ident: &str) -> impl Iterator<Item = IdentifierChar> + '_ {
let mut case_sensitive = false;
let mut last_char_quote = false;
ident.chars().filter_map(move |ch| {
if ch == '"' {
if last_char_quote {
last_char_quote = false;
} else {
last_char_quote = true;
return None;
}
} else if last_char_quote {
last_char_quote = false;
case_sensitive = !case_sensitive;
}
Some(IdentifierChar { ch, case_sensitive })
})
}
#[test]
fn test_name_eq() {
let test_values = [
("foo", "foo", true),
("foo", "Foo", true),
("foo", "FOO", true),
("foo", r#""foo""#, true),
("foo", r#""Foo""#, false),
("foo", "foo.foo", false),
("foo.foo", "foo.foo", true),
("foo.foo", "foo.Foo", true),
("foo.foo", "foo.FOO", true),
("foo.foo", "Foo.foo", true),
("foo.foo", "Foo.Foo", true),
("foo.foo", "FOO.FOO", true),
("foo.foo", "foo", false),
("foo.foo", r#"foo."foo""#, true),
("foo.foo", r#"foo."Foo""#, false),
("foo.foo", r#"foo."FOO""#, false),
];
for (left, right, eq) in test_values {
assert_eq!(
name_eq(left, right),
eq,
"failed check for name_eq({left:?}, {right:?})"
);
assert_eq!(
name_eq(right, left),
eq,
"failed check for name_eq({right:?}, {left:?})"
);
}
}