242 lines
7.8 KiB
Rust
242 lines
7.8 KiB
Rust
use proc_macro2::TokenStream;
|
|
use quote::{quote, quote_spanned};
|
|
use syn::spanned::Spanned;
|
|
|
|
use crate::attrs::*;
|
|
use crate::utils::*;
|
|
|
|
/// Comma-separated variant list of an enum, i.e. the type of `syn::ItemEnum::variants`.
type Variants = syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>;
|
|
|
|
/// Defines and implements `config_type` enum.
|
|
pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
|
|
let syn::ItemEnum {
|
|
vis,
|
|
enum_token,
|
|
ident,
|
|
generics,
|
|
variants,
|
|
..
|
|
} = em;
|
|
|
|
let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
|
|
let mod_name = syn::Ident::new(&mod_name_str, ident.span());
|
|
let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
|
|
|
|
let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
|
|
let impl_from_str = impl_from_str(&em.ident, &em.variants);
|
|
let impl_display = impl_display(&em.ident, &em.variants);
|
|
let impl_serde = impl_serde(&em.ident, &em.variants);
|
|
let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
|
|
|
|
Ok(quote! {
|
|
#[allow(non_snake_case)]
|
|
mod #mod_name {
|
|
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
|
pub #enum_token #ident #generics { #variants }
|
|
#impl_display
|
|
#impl_doc_hint
|
|
#impl_from_str
|
|
#impl_serde
|
|
#impl_deserialize
|
|
}
|
|
#vis use #mod_name::#ident;
|
|
})
|
|
}
|
|
|
|
/// Remove attributes specific to `config_proc_macro` from enum variant fields.
|
|
fn process_variant(variant: &syn::Variant) -> TokenStream {
|
|
let metas = variant
|
|
.attrs
|
|
.iter()
|
|
.filter(|attr| !is_doc_hint(attr) && !is_config_value(attr) && !is_unstable_variant(attr));
|
|
let attrs = fold_quote(metas, |meta| quote!(#meta));
|
|
let syn::Variant { ident, fields, .. } = variant;
|
|
quote!(#attrs #ident #fields)
|
|
}
|
|
|
|
/// Return the correct syntax to pattern match on the enum variant, discarding all
|
|
/// internal field data.
|
|
fn fields_in_variant(variant: &syn::Variant) -> TokenStream {
|
|
// With thanks to https://stackoverflow.com/a/65182902
|
|
match &variant.fields {
|
|
syn::Fields::Unnamed(_) => quote_spanned! { variant.span() => (..) },
|
|
syn::Fields::Unit => quote_spanned! { variant.span() => },
|
|
syn::Fields::Named(_) => quote_spanned! { variant.span() => {..} },
|
|
}
|
|
}
|
|
|
|
fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
|
|
let doc_hint = variants
|
|
.iter()
|
|
.map(doc_hint_of_variant)
|
|
.collect::<Vec<_>>()
|
|
.join("|");
|
|
let doc_hint = format!("[{}]", doc_hint);
|
|
|
|
let variant_stables = variants
|
|
.iter()
|
|
.map(|v| (&v.ident, fields_in_variant(&v), !unstable_of_variant(v)));
|
|
let match_patterns = fold_quote(variant_stables, |(v, fields, stable)| {
|
|
quote! {
|
|
#ident::#v #fields => #stable,
|
|
}
|
|
});
|
|
quote! {
|
|
use crate::config::ConfigType;
|
|
impl ConfigType for #ident {
|
|
fn doc_hint() -> String {
|
|
#doc_hint.to_owned()
|
|
}
|
|
fn stable_variant(&self) -> bool {
|
|
match self {
|
|
#match_patterns
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream {
|
|
let vs = variants
|
|
.iter()
|
|
.filter(|v| is_unit(v))
|
|
.map(|v| (config_value_of_variant(v), &v.ident));
|
|
let match_patterns = fold_quote(vs, |(s, v)| {
|
|
quote! {
|
|
#ident::#v => write!(f, "{}", #s),
|
|
}
|
|
});
|
|
quote! {
|
|
use std::fmt;
|
|
impl fmt::Display for #ident {
|
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
|
match self {
|
|
#match_patterns
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
|
|
let vs = variants
|
|
.iter()
|
|
.filter(|v| is_unit(v))
|
|
.map(|v| (config_value_of_variant(v), &v.ident));
|
|
let if_patterns = fold_quote(vs, |(s, v)| {
|
|
quote! {
|
|
if #s.eq_ignore_ascii_case(s) {
|
|
return Ok(#ident::#v);
|
|
}
|
|
}
|
|
});
|
|
let mut err_msg = String::from("Bad variant, expected one of:");
|
|
for v in variants.iter().filter(|v| is_unit(v)) {
|
|
err_msg.push_str(&format!(" `{}`", v.ident));
|
|
}
|
|
|
|
quote! {
|
|
impl ::std::str::FromStr for #ident {
|
|
type Err = &'static str;
|
|
|
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
|
#if_patterns
|
|
return Err(#err_msg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn doc_hint_of_variant(variant: &syn::Variant) -> String {
|
|
let mut text = find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string());
|
|
if unstable_of_variant(&variant) {
|
|
text.push_str(" (unstable)")
|
|
};
|
|
text
|
|
}
|
|
|
|
fn config_value_of_variant(variant: &syn::Variant) -> String {
|
|
find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
|
|
}
|
|
|
|
fn unstable_of_variant(variant: &syn::Variant) -> bool {
|
|
any_unstable_variant(&variant.attrs)
|
|
}
|
|
|
|
fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
|
|
let arms = fold_quote(variants.iter(), |v| {
|
|
let v_ident = &v.ident;
|
|
let pattern = match v.fields {
|
|
syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
|
|
syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
|
|
syn::Fields::Unit => quote!(#ident::#v_ident),
|
|
};
|
|
let option_value = config_value_of_variant(v);
|
|
quote! {
|
|
#pattern => serializer.serialize_str(&#option_value),
|
|
}
|
|
});
|
|
|
|
quote! {
|
|
impl ::serde::ser::Serialize for #ident {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: ::serde::ser::Serializer,
|
|
{
|
|
use serde::ser::Error;
|
|
match self {
|
|
#arms
|
|
_ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Currently only unit variants are supported.
|
|
fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
|
|
let supported_vs = variants.iter().filter(|v| is_unit(v));
|
|
let if_patterns = fold_quote(supported_vs, |v| {
|
|
let config_value = config_value_of_variant(v);
|
|
let variant_ident = &v.ident;
|
|
quote! {
|
|
if #config_value.eq_ignore_ascii_case(s) {
|
|
return Ok(#ident::#variant_ident);
|
|
}
|
|
}
|
|
});
|
|
|
|
let supported_vs = variants.iter().filter(|v| is_unit(v));
|
|
let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
|
|
|
|
quote! {
|
|
impl<'de> serde::de::Deserialize<'de> for #ident {
|
|
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
use serde::de::{Error, Visitor};
|
|
use std::marker::PhantomData;
|
|
use std::fmt;
|
|
struct StringOnly<T>(PhantomData<T>);
|
|
impl<'de, T> Visitor<'de> for StringOnly<T>
|
|
where T: serde::Deserializer<'de> {
|
|
type Value = String;
|
|
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
formatter.write_str("string")
|
|
}
|
|
fn visit_str<E>(self, value: &str) -> Result<String, E> {
|
|
Ok(String::from(value))
|
|
}
|
|
}
|
|
let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
|
|
|
|
#if_patterns
|
|
|
|
static ALLOWED: &'static[&str] = &[#allowed];
|
|
Err(D::Error::unknown_variant(&s, ALLOWED))
|
|
}
|
|
}
|
|
}
|
|
}
|