From 6ae689cbc388b3e2d060d1ce7fb3a04216332558 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20Kijewski?= Date: Mon, 10 Feb 2025 04:55:12 +0100 Subject: [PATCH] Re-use `Html` escaping code to implement JSON escaping --- Cargo.toml | 2 +- rinja/Cargo.toml | 2 +- rinja/src/ascii_str.rs | 13 +++ rinja/src/filters/json.rs | 182 +++++++++++++++++++++++++++++--------- 4 files changed, 154 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 57b393a3..843cc0a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,6 @@ members = [ "testing", "testing-alloc", "testing-no-std", - "testing-renamed" + "testing-renamed", ] resolver = "2" diff --git a/rinja/Cargo.toml b/rinja/Cargo.toml index d061508c..0f552c4c 100644 --- a/rinja/Cargo.toml +++ b/rinja/Cargo.toml @@ -57,7 +57,7 @@ blocks = ["rinja_derive?/blocks"] code-in-doc = ["rinja_derive?/code-in-doc"] config = ["rinja_derive?/config"] derive = ["rinja_derive"] -serde_json = ["rinja_derive?/serde_json", "dep:serde", "dep:serde_json"] +serde_json = ["std", "rinja_derive?/serde_json", "dep:serde", "dep:serde_json"] std = [ "alloc", "rinja_derive?/std", diff --git a/rinja/src/ascii_str.rs b/rinja/src/ascii_str.rs index 54674603..80dcc385 100644 --- a/rinja/src/ascii_str.rs +++ b/rinja/src/ascii_str.rs @@ -114,6 +114,19 @@ impl AsciiChar { Self::new(ALPHABET[d as usize % ALPHABET.len()]), ] } + + #[inline] + pub const fn two_hex_digits(d: u32) -> [Self; 2] { + const ALPHABET: &[u8; 16] = b"0123456789abcdef"; + + if d >= ALPHABET.len().pow(2) as u32 { + panic!(); + } + [ + Self::new(ALPHABET[d as usize / ALPHABET.len()]), + Self::new(ALPHABET[d as usize % ALPHABET.len()]), + ] + } } mod _ascii_char { diff --git a/rinja/src/filters/json.rs b/rinja/src/filters/json.rs index 2696064b..8a86aa52 100644 --- a/rinja/src/filters/json.rs +++ b/rinja/src/filters/json.rs @@ -4,9 +4,10 @@ use std::pin::Pin; use std::{fmt, io, str}; use serde::Serialize; -use serde_json::ser::{PrettyFormatter, Serializer, to_writer}; 
+use serde_json::ser::{CompactFormatter, PrettyFormatter, Serializer}; use super::FastWritable; +use crate::ascii_str::{AsciiChar, AsciiStr}; /// Serialize to JSON (requires `json` feature) /// @@ -187,9 +188,8 @@ where } impl FastWritable for ToJson { - #[inline] fn write_into(&self, f: &mut W) -> crate::Result<()> { - fmt_json(f, &self.value) + serialize(f, &self.value, CompactFormatter) } } @@ -201,9 +201,12 @@ impl fmt::Display for ToJson { } impl FastWritable for ToJsonPretty { - #[inline] fn write_into(&self, f: &mut W) -> crate::Result<()> { - fmt_json_pretty(f, &self.value, self.indent.as_indent()) + serialize( + f, + &self.value, + PrettyFormatter::with_indent(self.indent.as_indent().as_bytes()), + ) } } @@ -214,58 +217,151 @@ impl fmt::Display for ToJsonPretty { } } -fn fmt_json(dest: &mut W, value: &S) -> crate::Result<()> { - Ok(to_writer(JsonWriter(dest), value)?) -} +#[inline] +fn serialize(dest: &mut W, value: &S, formatter: F) -> Result<(), crate::Error> +where + S: Serialize + ?Sized, + W: fmt::Write + ?Sized, + F: serde_json::ser::Formatter, +{ + /// The struct must only ever be used with the output of `serde_json`. + /// `serde_json` only produces UTF-8 strings in its `io::Write::write()` calls, + /// and `JsonWriter::write_all()` depends on this invariant. 
+ struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W); + + impl io::Write for JsonWriter<'_, W> { + /// Invariant: must be passed valid UTF-8 slices + #[inline] + fn write(&mut self, bytes: &[u8]) -> io::Result { + self.write_all(bytes)?; + Ok(bytes.len()) + } + + /// Invariant: must be passed valid UTF-8 slices + fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> { + // SAFETY: `serde_json` only writes valid strings + let string = unsafe { std::str::from_utf8_unchecked(bytes) }; + write_escaped_str(&mut *self.0, string) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + /// Invariant: no character that needs escaping is a multi-byte character when encoded in UTF-8; + /// that is true for characters in ASCII range. + #[inline] + fn write_escaped_str(dest: &mut (impl fmt::Write + ?Sized), src: &str) -> fmt::Result { + // This implementation reads one byte after another. + // It's not very fast, but should work well enough until portable SIMD gets stabilized. + + let mut escaped_buf = ESCAPED_BUF_INIT; + let mut last = 0; + + for (index, byte) in src.bytes().enumerate() { + if let Some(escaped) = get_escaped(byte) { + [escaped_buf[4], escaped_buf[5]] = escaped; + write_str_if_nonempty(dest, &src[last..index])?; + dest.write_str(AsciiStr::from_slice(&escaped_buf[..ESCAPED_BUF_LEN]))?; + last = index + 1; + } + } + write_str_if_nonempty(dest, &src[last..]) + } -fn fmt_json_pretty( - dest: &mut W, - value: &S, - indent: &str, -) -> crate::Result<()> { - let formatter = PrettyFormatter::with_indent(indent.as_bytes()); let mut serializer = Serializer::with_formatter(JsonWriter(dest), formatter); Ok(value.serialize(&mut serializer)?) } -struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W); +/// Returns the two-digit hexadecimal representation of the codepoint if the character needs HTML escaping. 
+#[inline] +fn get_escaped(byte: u8) -> Option<[AsciiChar; 2]> { + const _: () = assert!(CHAR_RANGE < 32); -impl io::Write for JsonWriter<'_, W> { - #[inline] - fn write(&mut self, bytes: &[u8]) -> io::Result { - self.write_all(bytes)?; - Ok(bytes.len()) + if let MIN_CHAR..=MAX_CHAR = byte { + if (1u32 << (byte - MIN_CHAR)) & BITS != 0 { + return Some(TABLE.0[(byte - MIN_CHAR) as usize]); + } } + None +} - #[inline] - fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> { - write(self.0, bytes).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) - } - - #[inline] - fn flush(&mut self) -> io::Result<()> { +#[inline(always)] +fn write_str_if_nonempty(output: &mut (impl fmt::Write + ?Sized), input: &str) -> fmt::Result { + if !input.is_empty() { + output.write_str(input) + } else { Ok(()) } } -fn write(f: &mut W, bytes: &[u8]) -> fmt::Result { - let mut last = 0; - for (index, byte) in bytes.iter().enumerate() { - let escaped = match byte { - b'&' => Some(br"\u0026"), - b'\'' => Some(br"\u0027"), - b'<' => Some(br"\u003c"), - b'>' => Some(br"\u003e"), - _ => None, - }; - if let Some(escaped) = escaped { - f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..index]) })?; - f.write_str(unsafe { str::from_utf8_unchecked(escaped) })?; - last = index + 1; +/// List of characters that need HTML escaping, not necessarily in ordinal order. +const CHARS: &[u8] = br#"&'<>"#; + +/// The character with the lowest codepoint that needs HTML escaping. +const MIN_CHAR: u8 = { + let mut v = u8::MAX; + let mut i = 0; + while i < CHARS.len() { + if v > CHARS[i] { + v = CHARS[i]; } + i += 1; } - f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..]) }) -} + v +}; + +/// The character with the highest codepoint that needs HTML escaping. 
+const MAX_CHAR: u8 = { + let mut v = u8::MIN; + let mut i = 0; + while i < CHARS.len() { + if v < CHARS[i] { + v = CHARS[i]; + } + i += 1; + } + v +}; + +const BITS: u32 = { + let mut bits = 0; + let mut i = 0; + while i < CHARS.len() { + bits |= 1 << (CHARS[i] - MIN_CHAR); + i += 1; + } + bits +}; + +/// Number of codepoints between the lowest and highest character that needs escaping, inclusive. +const CHAR_RANGE: usize = (MAX_CHAR - MIN_CHAR + 1) as usize; + +#[repr(align(64))] +struct Table([[AsciiChar; 2]; CHAR_RANGE]); + +/// For characters that need HTML escaping, the codepoint is formatted as two hexadecimal digits, +/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`]. +const TABLE: &Table = &{ + let mut table = Table([UNESCAPED; CHAR_RANGE]); + let mut i = 0; + while i < CHARS.len() { + let c = CHARS[i]; + table.0[c as u32 as usize - MIN_CHAR as usize] = AsciiChar::two_hex_digits(c as u32); + i += 1; + } + table +}; + +const UNESCAPED: [AsciiChar; 2] = AsciiStr::new_sized(""); + +const ESCAPED_BUF_INIT_UNPADDED: &str = "\\u00__"; +// RATIONALE: llvm generates better code if the buffer is register sized +const ESCAPED_BUF_INIT: [AsciiChar; 8] = AsciiStr::new_sized(ESCAPED_BUF_INIT_UNPADDED); +const ESCAPED_BUF_LEN: usize = ESCAPED_BUF_INIT_UNPADDED.len(); #[cfg(all(test, feature = "alloc"))] mod tests {