Re-use Html escaping code to implement JSON escaping

This commit is contained in:
René Kijewski 2025-02-10 04:55:12 +01:00
parent 84edf1cc77
commit 6ae689cbc3
4 changed files with 154 additions and 45 deletions

View File

@ -6,6 +6,6 @@ members = [
"testing", "testing",
"testing-alloc", "testing-alloc",
"testing-no-std", "testing-no-std",
"testing-renamed" "testing-renamed",
] ]
resolver = "2" resolver = "2"

View File

@ -57,7 +57,7 @@ blocks = ["rinja_derive?/blocks"]
code-in-doc = ["rinja_derive?/code-in-doc"] code-in-doc = ["rinja_derive?/code-in-doc"]
config = ["rinja_derive?/config"] config = ["rinja_derive?/config"]
derive = ["rinja_derive"] derive = ["rinja_derive"]
serde_json = ["rinja_derive?/serde_json", "dep:serde", "dep:serde_json"] serde_json = ["std", "rinja_derive?/serde_json", "dep:serde", "dep:serde_json"]
std = [ std = [
"alloc", "alloc",
"rinja_derive?/std", "rinja_derive?/std",

View File

@ -114,6 +114,19 @@ impl AsciiChar {
Self::new(ALPHABET[d as usize % ALPHABET.len()]), Self::new(ALPHABET[d as usize % ALPHABET.len()]),
] ]
} }
#[inline]
pub const fn two_hex_digits(d: u32) -> [Self; 2] {
const ALPHABET: &[u8; 16] = b"0123456789abcdef";
if d >= ALPHABET.len().pow(2) as u32 {
panic!();
}
[
Self::new(ALPHABET[d as usize / ALPHABET.len()]),
Self::new(ALPHABET[d as usize % ALPHABET.len()]),
]
}
} }
mod _ascii_char { mod _ascii_char {

View File

@ -4,9 +4,10 @@ use std::pin::Pin;
use std::{fmt, io, str}; use std::{fmt, io, str};
use serde::Serialize; use serde::Serialize;
use serde_json::ser::{PrettyFormatter, Serializer, to_writer}; use serde_json::ser::{CompactFormatter, PrettyFormatter, Serializer};
use super::FastWritable; use super::FastWritable;
use crate::ascii_str::{AsciiChar, AsciiStr};
/// Serialize to JSON (requires `json` feature) /// Serialize to JSON (requires `json` feature)
/// ///
@ -187,9 +188,8 @@ where
} }
impl<S: Serialize> FastWritable for ToJson<S> { impl<S: Serialize> FastWritable for ToJson<S> {
#[inline]
fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> { fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> {
fmt_json(f, &self.value) serialize(f, &self.value, CompactFormatter)
} }
} }
@ -201,9 +201,12 @@ impl<S: Serialize> fmt::Display for ToJson<S> {
} }
impl<S: Serialize, I: AsIndent> FastWritable for ToJsonPretty<S, I> { impl<S: Serialize, I: AsIndent> FastWritable for ToJsonPretty<S, I> {
#[inline]
fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> { fn write_into<W: fmt::Write + ?Sized>(&self, f: &mut W) -> crate::Result<()> {
fmt_json_pretty(f, &self.value, self.indent.as_indent()) serialize(
f,
&self.value,
PrettyFormatter::with_indent(self.indent.as_indent().as_bytes()),
)
} }
} }
@ -214,58 +217,151 @@ impl<S: Serialize, I: AsIndent> fmt::Display for ToJsonPretty<S, I> {
} }
} }
fn fmt_json<S: Serialize, W: fmt::Write + ?Sized>(dest: &mut W, value: &S) -> crate::Result<()> { #[inline]
Ok(to_writer(JsonWriter(dest), value)?) fn serialize<S, W, F>(dest: &mut W, value: &S, formatter: F) -> Result<(), crate::Error>
} where
S: Serialize + ?Sized,
W: fmt::Write + ?Sized,
F: serde_json::ser::Formatter,
{
/// The struct must only ever be used with the output of `serde_json`.
/// `serde_json` only produces UTF-8 strings in its `io::Write::write()` calls,
/// and `<JsonWriter as io::Write>` depends on this invariant.
struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W);
impl<W: fmt::Write + ?Sized> io::Write for JsonWriter<'_, W> {
/// Invariant: must be passed valid UTF-8 slices
#[inline]
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
self.write_all(bytes)?;
Ok(bytes.len())
}
/// Invariant: must be passed valid UTF-8 slices
fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
// SAFETY: `serde_json` only writes valid strings
let string = unsafe { std::str::from_utf8_unchecked(bytes) };
write_escaped_str(&mut *self.0, string)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
}
#[inline]
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
/// Invariant: no character that needs escaping is multi-byte character when encoded in UTF-8;
/// that is true for characters in ASCII range.
#[inline]
fn write_escaped_str(dest: &mut (impl fmt::Write + ?Sized), src: &str) -> fmt::Result {
// This implementation reads one byte after another.
// It's not very fast, but should work well enough until portable SIMD gets stabilized.
let mut escaped_buf = ESCAPED_BUF_INIT;
let mut last = 0;
for (index, byte) in src.bytes().enumerate() {
if let Some(escaped) = get_escaped(byte) {
[escaped_buf[4], escaped_buf[5]] = escaped;
write_str_if_nonempty(dest, &src[last..index])?;
dest.write_str(AsciiStr::from_slice(&escaped_buf[..ESCAPED_BUF_LEN]))?;
last = index + 1;
}
}
write_str_if_nonempty(dest, &src[last..])
}
fn fmt_json_pretty<S: Serialize, W: fmt::Write + ?Sized>(
dest: &mut W,
value: &S,
indent: &str,
) -> crate::Result<()> {
let formatter = PrettyFormatter::with_indent(indent.as_bytes());
let mut serializer = Serializer::with_formatter(JsonWriter(dest), formatter); let mut serializer = Serializer::with_formatter(JsonWriter(dest), formatter);
Ok(value.serialize(&mut serializer)?) Ok(value.serialize(&mut serializer)?)
} }
struct JsonWriter<'a, W: fmt::Write + ?Sized>(&'a mut W); /// Returns the decimal representation of the codepoint if the character needs HTML escaping.
#[inline]
fn get_escaped(byte: u8) -> Option<[AsciiChar; 2]> {
const _: () = assert!(CHAR_RANGE < 32);
impl<W: fmt::Write + ?Sized> io::Write for JsonWriter<'_, W> { if let MIN_CHAR..=MAX_CHAR = byte {
#[inline] if (1u32 << (byte - MIN_CHAR)) & BITS != 0 {
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> { return Some(TABLE.0[(byte - MIN_CHAR) as usize]);
self.write_all(bytes)?; }
Ok(bytes.len())
} }
None
}
#[inline] #[inline(always)]
fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> { fn write_str_if_nonempty(output: &mut (impl fmt::Write + ?Sized), input: &str) -> fmt::Result {
write(self.0, bytes).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) if !input.is_empty() {
} output.write_str(input)
} else {
#[inline]
fn flush(&mut self) -> io::Result<()> {
Ok(()) Ok(())
} }
} }
fn write<W: fmt::Write + ?Sized>(f: &mut W, bytes: &[u8]) -> fmt::Result { /// List of characters that need HTML escaping, not necessarily in ordinal order.
let mut last = 0; const CHARS: &[u8] = br#"&'<>"#;
for (index, byte) in bytes.iter().enumerate() {
let escaped = match byte { /// The character with the lowest codepoint that needs HTML escaping.
b'&' => Some(br"\u0026"), const MIN_CHAR: u8 = {
b'\'' => Some(br"\u0027"), let mut v = u8::MAX;
b'<' => Some(br"\u003c"), let mut i = 0;
b'>' => Some(br"\u003e"), while i < CHARS.len() {
_ => None, if v > CHARS[i] {
}; v = CHARS[i];
if let Some(escaped) = escaped {
f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..index]) })?;
f.write_str(unsafe { str::from_utf8_unchecked(escaped) })?;
last = index + 1;
} }
i += 1;
} }
f.write_str(unsafe { str::from_utf8_unchecked(&bytes[last..]) }) v
} };
/// The character with the highest codepoint that needs HTML escaping.
const MAX_CHAR: u8 = {
let mut v = u8::MIN;
let mut i = 0;
while i < CHARS.len() {
if v < CHARS[i] {
v = CHARS[i];
}
i += 1;
}
v
};
const BITS: u32 = {
let mut bits = 0;
let mut i = 0;
while i < CHARS.len() {
bits |= 1 << (CHARS[i] - MIN_CHAR);
i += 1;
}
bits
};
/// Number of codepoints between the lowest and highest character that needs escaping, incl.
const CHAR_RANGE: usize = (MAX_CHAR - MIN_CHAR + 1) as usize;
#[repr(align(64))]
struct Table([[AsciiChar; 2]; CHAR_RANGE]);
/// For characters that need HTML escaping, the codepoint is formatted as decimal digits,
/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
const TABLE: &Table = &{
let mut table = Table([UNESCAPED; CHAR_RANGE]);
let mut i = 0;
while i < CHARS.len() {
let c = CHARS[i];
table.0[c as u32 as usize - MIN_CHAR as usize] = AsciiChar::two_hex_digits(c as u32);
i += 1;
}
table
};
const UNESCAPED: [AsciiChar; 2] = AsciiStr::new_sized("");
const ESCAPED_BUF_INIT_UNPADDED: &str = "\\u00__";
// RATIONALE: llvm generates better code if the buffer is register sized
const ESCAPED_BUF_INIT: [AsciiChar; 8] = AsciiStr::new_sized(ESCAPED_BUF_INIT_UNPADDED);
const ESCAPED_BUF_LEN: usize = ESCAPED_BUF_INIT_UNPADDED.len();
#[cfg(all(test, feature = "alloc"))] #[cfg(all(test, feature = "alloc"))]
mod tests { mod tests {