Implement enum variants

This commit is contained in:
René Kijewski 2024-11-19 03:00:11 +01:00
parent e418834149
commit 5944ab9bef
7 changed files with 553 additions and 45 deletions

View File

@ -172,7 +172,8 @@ jobs:
with:
tool: cargo-nextest
- uses: Swatinem/rust-cache@v2
- run: cd ${{ matrix.package }} && cargo nextest run --no-tests=warn
- run: cd ${{ matrix.package }} && cargo build --all-targets
- run: cd ${{ matrix.package }} && cargo nextest run --all-targets --no-fail-fast --no-tests=warn
- run: cd ${{ matrix.package }} && cargo clippy --all-targets -- -D warnings
MSRV:

View File

@ -269,3 +269,11 @@ impl<L: FastWritable, R: FastWritable> FastWritable for Concat<L, R> {
self.1.write_into(dest)
}
}
pub trait EnumVariantTemplate {
fn render_into_with_values<W: fmt::Write + ?Sized>(
&self,
writer: &mut W,
values: &dyn crate::Values,
) -> crate::Result<()>;
}

View File

@ -25,9 +25,12 @@ pub(crate) fn template_to_string(
input: &TemplateInput<'_>,
contexts: &HashMap<&Arc<Path>, Context<'_>, FxBuildHasher>,
heritage: Option<&Heritage<'_, '_>>,
target: Option<&str>,
tmpl_kind: TmplKind,
) -> Result<usize, CompileError> {
let ctx = &contexts[&input.path];
if tmpl_kind == TmplKind::Struct {
buf.write("const _: () = { extern crate rinja as rinja;");
}
let generator = Generator::new(
input,
contexts,
@ -36,13 +39,27 @@ pub(crate) fn template_to_string(
input.block.is_some(),
0,
);
let mut result = generator.build(ctx, buf, target);
if let Err(err) = &mut result {
if err.span.is_none() {
let size_hint = match generator.impl_template(buf, tmpl_kind) {
Err(mut err) if err.span.is_none() => {
err.span = input.source_span;
Err(err)
}
result => result,
}?;
if tmpl_kind == TmplKind::Struct {
impl_everything(input.ast, buf);
buf.write("};");
}
result
Ok(size_hint)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TmplKind {
/// [`rinja::Template`]
Struct,
/// [`rinja::helpers::EnumVariantTemplate`]
Variant,
}
struct Generator<'a, 'h> {
@ -97,31 +114,18 @@ impl<'a, 'h> Generator<'a, 'h> {
}
}
// Takes a Context and generates the relevant implementations.
fn build(
mut self,
ctx: &Context<'a>,
buf: &mut Buffer,
target: Option<&str>,
) -> Result<usize, CompileError> {
if target.is_none() {
buf.write("const _: () = { extern crate rinja as rinja;");
}
let size_hint = self.impl_template(ctx, buf, target.unwrap_or("rinja::Template"))?;
if target.is_none() {
impl_everything(self.input.ast, buf);
buf.write("};");
}
Ok(size_hint)
}
// Implement `Template` for the given context struct.
fn impl_template(
&mut self,
ctx: &Context<'a>,
mut self,
buf: &mut Buffer,
target: &str,
tmpl_kind: TmplKind,
) -> Result<usize, CompileError> {
let ctx = &self.contexts[&self.input.path];
let target = match tmpl_kind {
TmplKind::Struct => "rinja::Template",
TmplKind::Variant => "rinja::helpers::EnumVariantTemplate",
};
write_header(self.input.ast, buf, target);
buf.write(
"fn render_into_with_values<RinjaW>(\
@ -161,12 +165,12 @@ impl<'a, 'h> Generator<'a, 'h> {
let size_hint = self.impl_template_inner(ctx, buf)?;
buf.write(format_args!(
"\
rinja::Result::Ok(())\
}}\
const SIZE_HINT: rinja::helpers::core::primitive::usize = {size_hint}usize;",
));
buf.write("rinja::Result::Ok(()) }");
if tmpl_kind == TmplKind::Struct {
buf.write(format_args!(
"const SIZE_HINT: rinja::helpers::core::primitive::usize = {size_hint}usize;",
));
}
buf.write('}');
Ok(size_hint)

View File

@ -271,6 +271,63 @@ impl TemplateInput<'_> {
}
}
pub(crate) enum AnyTemplateArgs {
Struct(TemplateArgs),
Enum {
enum_args: Option<PartialTemplateArgs>,
vars_args: Vec<Option<PartialTemplateArgs>>,
has_default_impl: bool,
},
}
impl AnyTemplateArgs {
pub(crate) fn new(ast: &syn::DeriveInput) -> Result<Self, CompileError> {
let syn::Data::Enum(enum_data) = &ast.data else {
return Ok(Self::Struct(TemplateArgs::new(ast)?));
};
let enum_args = PartialTemplateArgs::new(ast, &ast.attrs)?;
let vars_args = enum_data
.variants
.iter()
.map(|variant| PartialTemplateArgs::new(ast, &variant.attrs))
.collect::<Result<Vec<_>, _>>()?;
if vars_args.is_empty() {
return Ok(Self::Struct(TemplateArgs::from_partial(ast, enum_args)?));
}
let mut needs_default_impl = vars_args.len();
let enum_source = enum_args.as_ref().and_then(|v| v.source.as_ref());
for (variant, var_args) in enum_data.variants.iter().zip(&vars_args) {
if var_args
.as_ref()
.and_then(|v| v.source.as_ref())
.or(enum_source)
.is_none()
{
return Err(CompileError::new_with_span(
#[cfg(not(feature = "code-in-doc"))]
"either all `enum` variants need a `path` or `source` argument, \
or the `enum` itself needs a default implementation",
#[cfg(feature = "code-in-doc")]
"either all `enum` variants need a `path`, `source` or `in_doc` argument, \
or the `enum` itself needs a default implementation",
None,
Some(variant.ident.span()),
));
} else if !var_args.is_none() {
needs_default_impl -= 1;
}
}
Ok(Self::Enum {
enum_args,
vars_args,
has_default_impl: needs_default_impl > 0,
})
}
}
#[derive(Debug)]
pub(crate) struct TemplateArgs {
pub(crate) source: (Source, Option<Span>),
@ -626,6 +683,17 @@ pub(crate) enum PartialTemplateArgsSource {
InDoc(Span, Source),
}
impl PartialTemplateArgsSource {
pub(crate) fn span(&self) -> Span {
match self {
Self::Path(s) => s.span(),
Self::Source(s) => s.span(),
#[cfg(feature = "code-in-doc")]
Self::InDoc(s, _) => s.span(),
}
}
}
// implement PartialTemplateArgs::new()
const _: () = {
impl PartialTemplateArgs {

View File

@ -1,7 +1,16 @@
use std::fmt::{Arguments, Display, Write};
use quote::quote;
use syn::DeriveInput;
use proc_macro2::{TokenStream, TokenTree};
use quote::{ToTokens, quote};
use syn::spanned::Spanned;
use syn::{
Data, DeriveInput, Fields, GenericParam, Generics, Ident, Lifetime, LifetimeParam, Token, Type,
Variant, parse_quote,
};
use crate::generator::TmplKind;
use crate::input::{PartialTemplateArgs, TemplateArgs};
use crate::{CompileError, build_template_item};
/// Implement every integration for the given item
pub(crate) fn impl_everything(ast: &DeriveInput, buf: &mut Buffer) {
@ -223,3 +232,291 @@ fn string_escape(dest: &mut String, src: &str) {
}
dest.extend(&src[last..]);
}
pub(crate) fn build_template_enum(
buf: &mut Buffer,
enum_ast: &DeriveInput,
mut enum_args: Option<PartialTemplateArgs>,
vars_args: Vec<Option<PartialTemplateArgs>>,
has_default_impl: bool,
) -> Result<usize, CompileError> {
let Data::Enum(enum_data) = &enum_ast.data else {
unreachable!();
};
buf.write("const _: () = { extern crate rinja as rinja;");
impl_everything(enum_ast, buf);
let enum_id = &enum_ast.ident;
let enum_span = enum_id.span();
let lifetime = Lifetime::new(&format!("'__Rinja_{enum_id}"), enum_span);
let mut generics = enum_ast.generics.clone();
if generics.lt_token.is_none() {
generics.lt_token = Some(Token![<](enum_span));
}
if generics.gt_token.is_none() {
generics.gt_token = Some(Token![>](enum_span));
}
generics
.params
.insert(0, GenericParam::Lifetime(LifetimeParam::new(lifetime)));
let mut biggest_size_hint = 0;
let mut render_into_arms = TokenStream::new();
let mut size_hint_arms = TokenStream::new();
for (var, var_args) in enum_data.variants.iter().zip(vars_args) {
let Some(mut var_args) = var_args else {
continue;
};
let var_ast = type_for_enum_variant(enum_ast, &generics, var);
buf.write(quote!(#var_ast).to_string());
// not inherited: template, meta_docs, block, print
if let Some(enum_args) = &mut enum_args {
set_default(&mut var_args, enum_args, |v| &mut v.source);
set_default(&mut var_args, enum_args, |v| &mut v.escape);
set_default(&mut var_args, enum_args, |v| &mut v.ext);
set_default(&mut var_args, enum_args, |v| &mut v.syntax);
set_default(&mut var_args, enum_args, |v| &mut v.config);
set_default(&mut var_args, enum_args, |v| &mut v.whitespace);
}
let size_hint = biggest_size_hint.max(build_template_item(
buf,
&var_ast,
&TemplateArgs::from_partial(&var_ast, Some(var_args))?,
TmplKind::Variant,
)?);
biggest_size_hint = biggest_size_hint.max(size_hint);
variant_as_arm(
&var_ast,
var,
size_hint,
&mut render_into_arms,
&mut size_hint_arms,
);
}
if has_default_impl {
let size_hint = build_template_item(
buf,
enum_ast,
&TemplateArgs::from_partial(enum_ast, enum_args)?,
TmplKind::Variant,
)?;
biggest_size_hint = biggest_size_hint.max(size_hint);
render_into_arms.extend(quote! {
ref __rinja_arg => {
<_ as rinja::helpers::EnumVariantTemplate>::render_into_with_values(
__rinja_arg,
__rinja_writer,
__rinja_values,
)
}
});
size_hint_arms.extend(quote! {
_ => {
#size_hint
}
});
}
write_header(enum_ast, buf, "rinja::Template");
buf.write(format_args!(
"\
fn render_into_with_values<RinjaW>(\
&self,\
__rinja_writer: &mut RinjaW,\
__rinja_values: &dyn rinja::Values,\
) -> rinja::Result<()>\
where \
RinjaW: rinja::helpers::core::fmt::Write + ?rinja::helpers::core::marker::Sized\
{{\
match *self {{\
{render_into_arms}\
}}\
}}",
));
#[cfg(feature = "alloc")]
buf.write(format_args!(
"\
fn render_with_values(\
&self,\
__rinja_values: &dyn rinja::Values,\
) -> rinja::Result<rinja::helpers::alloc::string::String> {{\
let size_hint = match self {{\
{size_hint_arms}\
}};\
let mut buf = rinja::helpers::alloc::string::String::new();\
let _ = buf.try_reserve(size_hint);\
self.render_into_with_values(&mut buf, __rinja_values)?;\
rinja::Result::Ok(buf)\
}}",
));
buf.write(format_args!(
"\
const SIZE_HINT: rinja::helpers::core::primitive::usize = {biggest_size_hint}usize;\
}}\
}};",
));
Ok(biggest_size_hint)
}
fn set_default<S, T, A>(dest: &mut S, parent: &mut S, mut access: A)
where
T: Clone,
A: FnMut(&mut S) -> &mut Option<T>,
{
let dest = access(dest);
if dest.is_none() {
if let Some(parent) = access(parent) {
*dest = Some(parent.clone());
}
}
}
/// Generates a `struct` to contain the data of an enum variant
fn type_for_enum_variant(
enum_ast: &DeriveInput,
enum_generics: &Generics,
var: &Variant,
) -> DeriveInput {
let enum_id = &enum_ast.ident;
let (_, ty_generics, _) = enum_ast.generics.split_for_impl();
let lt = enum_generics.params.first().unwrap();
let id = &var.ident;
let span = id.span();
let id = Ident::new(&format!("__Rinja__{enum_id}__{id}"), span);
let phantom: Type = parse_quote! {
rinja::helpers::core::marker::PhantomData < &#lt #enum_id #ty_generics >
};
let fields = match &var.fields {
Fields::Named(fields) => {
let mut fields = fields.clone();
for f in fields.named.iter_mut() {
let ty = &f.ty;
f.ty = parse_quote!(&#lt #ty);
}
let id = Ident::new(&format!("__Rinja__{enum_id}__phantom"), span);
fields.named.push(parse_quote!(#id: #phantom));
Fields::Named(fields)
}
Fields::Unnamed(fields) => {
let mut fields = fields.clone();
for f in fields.unnamed.iter_mut() {
let ty = &f.ty;
f.ty = parse_quote!(&#lt #ty);
}
fields.unnamed.push(parse_quote!(#phantom));
Fields::Unnamed(fields)
}
Fields::Unit => Fields::Unnamed(parse_quote!((#phantom))),
};
let semicolon = match &var.fields {
Fields::Named(_) => None,
_ => Some(Token![;](span)),
};
parse_quote! {
#[rinja::helpers::core::prelude::rust_2021::derive(
rinja::helpers::core::prelude::rust_2021::Clone,
rinja::helpers::core::prelude::rust_2021::Copy,
rinja::helpers::core::prelude::rust_2021::Debug
)]
#[allow(dead_code, non_camel_case_types, non_snake_case)]
struct #id #enum_generics #fields #semicolon
}
}
/// Generates a `match` arm for an `enum` variant, that calls `<_ as EnumVariantTemplate>::render_into()`
/// for that type and data
fn variant_as_arm(
var_ast: &DeriveInput,
var: &Variant,
size_hint: usize,
render_into_arms: &mut TokenStream,
size_hint_arms: &mut TokenStream,
) {
let var_id = &var_ast.ident;
let ident = &var.ident;
let span = ident.span();
let generics = var_ast.generics.clone();
let (_, ty_generics, _) = generics.split_for_impl();
let ty_generics: TokenStream = ty_generics
.as_turbofish()
.to_token_stream()
.into_iter()
.enumerate()
.map(|(idx, token)| match idx {
// 0 1 2 3 4 => : : < ' __Rinja_Foo
4 => TokenTree::Ident(Ident::new("_", span)),
_ => token,
})
.collect();
let Data::Struct(ast_data) = &var_ast.data else {
unreachable!();
};
let mut src = TokenStream::new();
let mut this = TokenStream::new();
match &var.fields {
Fields::Named(fields) => {
for (idx, field) in fields.named.iter().enumerate() {
let arg = Ident::new(&format!("__rinja_arg_{idx}"), field.span());
let id = field.ident.as_ref().unwrap();
src.extend(quote!(#id: ref #arg,));
this.extend(quote!(#id: #arg,));
}
let phantom = match &ast_data.fields {
Fields::Named(fields) => fields
.named
.iter()
.next_back()
.unwrap()
.ident
.as_ref()
.unwrap(),
Fields::Unnamed(_) | Fields::Unit => unreachable!(),
};
this.extend(quote!(#phantom: rinja::helpers::core::marker::PhantomData {},));
}
Fields::Unnamed(fields) => {
for (idx, field) in fields.unnamed.iter().enumerate() {
let span = field.ident.span();
let arg = Ident::new(&format!("__rinja_arg_{idx}"), span);
let idx = syn::LitInt::new(&format!("{idx}"), span);
src.extend(quote!(#idx: ref #arg,));
this.extend(quote!(#idx: #arg,));
}
let idx = syn::LitInt::new(&format!("{}", fields.unnamed.len()), span);
this.extend(quote!(#idx: rinja::helpers::core::marker::PhantomData {},));
}
Fields::Unit => {
this.extend(quote!(0: rinja::helpers::core::marker::PhantomData {},));
}
};
render_into_arms.extend(quote! {
Self :: #ident { #src } => {
<_ as rinja::helpers::EnumVariantTemplate>::render_into_with_values(
& #var_id #ty_generics { #this },
__rinja_writer,
__rinja_values,
)
}
});
size_hint_arms.extend(quote! {
Self :: #ident { .. } => {
#size_hint
}
});
}

View File

@ -19,10 +19,10 @@ use std::path::Path;
use std::sync::Mutex;
use config::{Config, read_config_file};
use generator::template_to_string;
use generator::{TmplKind, template_to_string};
use heritage::{Context, Heritage};
use input::{Print, TemplateArgs, TemplateInput};
use integration::Buffer;
use input::{AnyTemplateArgs, Print, TemplateArgs, TemplateInput};
use integration::{Buffer, build_template_enum};
use parser::{Parsed, strip_common};
#[cfg(not(feature = "__standalone"))]
use proc_macro::TokenStream as TokenStream12;
@ -159,7 +159,7 @@ fn build_skeleton(buf: &mut Buffer, ast: &syn::DeriveInput) -> Result<usize, Com
let mut contexts = HashMap::default();
let parsed = parser::Parsed::default();
contexts.insert(&input.path, Context::empty(&parsed));
template_to_string(buf, &input, &contexts, None, None)
template_to_string(buf, &input, &contexts, None, TmplKind::Struct)
}
/// Takes a `syn::DeriveInput` and generates source code for it
@ -173,11 +173,28 @@ pub(crate) fn build_template(
buf: &mut Buffer,
ast: &syn::DeriveInput,
) -> Result<usize, CompileError> {
let template_args = TemplateArgs::new(ast)?;
let mut result = build_template_item(buf, ast, &template_args, None);
let err_span;
let mut result = match AnyTemplateArgs::new(ast)? {
AnyTemplateArgs::Struct(item) => {
err_span = item.source.1.or(item.template_span);
build_template_item(buf, ast, &item, TmplKind::Struct)
}
AnyTemplateArgs::Enum {
enum_args,
vars_args,
has_default_impl,
} => {
err_span = enum_args
.as_ref()
.and_then(|v| v.source.as_ref())
.map(|s| s.span())
.or_else(|| enum_args.as_ref().map(|v| v.template.span()));
build_template_enum(buf, ast, enum_args, vars_args, has_default_impl)
}
};
if let Err(err) = &mut result {
if err.span.is_none() {
err.span = template_args.source.1.or(template_args.template_span);
err.span = err_span;
}
}
result
@ -187,7 +204,7 @@ fn build_template_item(
buf: &mut Buffer,
ast: &syn::DeriveInput,
template_args: &TemplateArgs,
target: Option<&str>,
tmpl_kind: TmplKind,
) -> Result<usize, CompileError> {
let config_path = template_args.config_path();
let s = read_config_file(config_path, template_args.config_span)?;
@ -230,7 +247,7 @@ fn build_template_item(
}
let mark = buf.get_mark();
let size_hint = template_to_string(buf, &input, &contexts, heritage.as_ref(), target)?;
let size_hint = template_to_string(buf, &input, &contexts, heritage.as_ref(), tmpl_kind)?;
if input.print == Print::Code || input.print == Print::All {
eprintln!("{}", buf.marked_text(mark));
}

113
testing/tests/enum.rs Normal file
View File

@ -0,0 +1,113 @@
use std::any::type_name_of_val;
use std::fmt::{Debug, Display};
use rinja::Template;
#[test]
fn test_simple_enum() {
#[derive(Template, Debug)]
#[template(
ext = "txt",
source = "{{ self::type_name_of_val(self) }} | {{ self|fmt(\"{:?}\") }}"
)]
enum SimpleEnum<'a, B: Display + Debug> {
#[template(source = "A")]
A,
#[template(source = "B()")]
B(),
#[template(source = "C({{self.0}}, {{self.1}})")]
C(u32, u32),
#[template(source = "D {}")]
D {},
#[template(source = "E { a: {{a}}, b: {{b}} }")]
E {
a: &'a str,
b: B,
},
// uses default source with `SimpleEnum` as `Self`
F,
// uses default source with a synthetic type `__Rinja__SimpleEnum__G` as `Self`
#[template()]
G,
}
let tmpl: SimpleEnum<'_, X> = SimpleEnum::A;
assert_eq!(tmpl.render().unwrap(), "A");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::B();
assert_eq!(tmpl.render().unwrap(), "B()");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::C(12, 34);
assert_eq!(tmpl.render().unwrap(), "C(12, 34)");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::C(12, 34);
assert_eq!(tmpl.render().unwrap(), "C(12, 34)");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::D {};
assert_eq!(tmpl.render().unwrap(), "D {}");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::E { a: "hello", b: X };
assert_eq!(tmpl.render().unwrap(), "E { a: hello, b: X }");
let tmpl: SimpleEnum<'_, X> = SimpleEnum::F;
assert_eq!(
tmpl.render().unwrap(),
"&enum::test_simple_enum::SimpleEnum<enum::X> | F",
);
let tmpl: SimpleEnum<'_, X> = SimpleEnum::G;
assert_eq!(
tmpl.render().unwrap(),
"&enum::test_simple_enum::_::__Rinja__SimpleEnum__G<enum::X> | \
__Rinja__SimpleEnum__G(\
PhantomData<&enum::test_simple_enum::SimpleEnum<enum::X>>\
)",
);
}
#[test]
fn test_enum_blocks() {
#[derive(Template, Debug)]
#[template(
ext = "txt",
source = "\
{% block a -%} <a = {{ a }}> {%- endblock %}
{% block b -%} <b = {{ b }}> {%- endblock %}
{% block c -%} <c = {{ c }}> {%- endblock %}
{% block d -%} <d = {{ self::type_name_of_val(self) }}> {%- endblock %}
"
)]
enum BlockEnum<'a, C: Display> {
#[template(block = "a")]
A { a: u32 },
#[template(block = "b")]
B { b: &'a str },
#[template(block = "c")]
C { c: C },
#[template(block = "d")]
D,
}
let tmpl: BlockEnum<'_, X> = BlockEnum::A { a: 42 };
assert_eq!(tmpl.render().unwrap(), "<a = 42>");
let tmpl: BlockEnum<'_, X> = BlockEnum::B { b: "second letter" };
assert_eq!(tmpl.render().unwrap(), "<b = second letter>");
let tmpl: BlockEnum<'_, X> = BlockEnum::C { c: X };
assert_eq!(tmpl.render().unwrap(), "<c = X>");
assert_eq!(
BlockEnum::<'_, X>::D.render().unwrap(),
"<d = &enum::test_enum_blocks::_::__Rinja__BlockEnum__D<enum::X>>"
);
}
#[derive(Debug)]
struct X;
impl Display for X {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("X")
}
}