Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix incorrect fmt::Pointer implementations #328

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions impl/src/fmt/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::fmt;

use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_quote, spanned::Spanned as _};
use syn::{parse_quote, spanned::Spanned as _, token};

use crate::utils::{attr::ParseMultiple as _, Spanning};

use super::{trait_name_to_attribute_name, ContainerAttributes};
use super::{trait_name_to_attribute_name, ContainerAttributes, FmtArgument};

/// Expands a [`fmt::Display`]-like derive macro.
///
Expand Down Expand Up @@ -85,8 +85,9 @@ fn expand_struct(
trait_ident,
ident,
};

let expr = s.generate_expr()?;
let bounds = s.generate_bounds();
let body = s.generate_body()?;

let vars = s.fields.iter().enumerate().map(|(i, f)| {
let var = f.ident.clone().unwrap_or_else(|| format_ident!("_{i}"));
Expand All @@ -101,7 +102,7 @@ fn expand_struct(

let body = quote! {
#( #vars )*
#body
#expr
};

Ok((bounds, body))
Expand Down Expand Up @@ -143,7 +144,8 @@ fn expand_enum(
trait_ident,
ident,
};
let arm_body = v.generate_body()?;

let arm_body = v.generate_expr()?;
bounds.extend(v.generate_bounds());

let fields_idents =
Expand Down Expand Up @@ -216,22 +218,32 @@ struct Expansion<'a> {
}

impl<'a> Expansion<'a> {
/// Generates [`Display::fmt()`] implementation for a struct or an enum variant.
/// Generates [`Display::fmt()`] implementation expression for a struct or an enum variant.
///
/// # Errors
///
/// In case [`FmtAttribute`] is [`None`] and [`syn::Fields`] length is
/// greater than 1.
/// In case [`FmtAttribute`] is [`None`] and [`syn::Fields`] length is greater than 1.
///
/// [`Display::fmt()`]: fmt::Display::fmt()
/// [`FmtAttribute`]: super::FmtAttribute
fn generate_body(&self) -> syn::Result<TokenStream> {
fn generate_expr(&self) -> syn::Result<TokenStream> {
match &self.attrs.fmt {
Some(fmt) => {
Ok(if let Some((expr, trait_ident)) = fmt.transparent_call() {
quote! { derive_more::core::fmt::#trait_ident::fmt(&(#expr), __derive_more_f) }
} else {
quote! { derive_more::core::write!(__derive_more_f, #fmt) }
let mut fmt_expr = fmt.clone();
let additional_args = fmt.iter_used_fields(&self.fields).map(
|(name, _)| -> FmtArgument {
parse_quote! { #name = *#name }
},
);
fmt_expr.args.extend(additional_args);
if !fmt_expr.args.is_empty() { // TODO: Move into separate method.
fmt_expr.comma = Some(token::Comma::default());
}

quote! { derive_more::core::write!(__derive_more_f, #fmt_expr) }
})
}
None if self.fields.is_empty() => {
Expand All @@ -246,7 +258,7 @@ impl<'a> Expansion<'a> {
.fields
.iter()
.next()
.unwrap_or_else(|| unreachable!("count() == 1"));
.unwrap_or_else(|| unreachable!("fields.len() == 1"));
let ident = field.ident.clone().unwrap_or_else(|| format_ident!("_0"));
let trait_ident = self.trait_ident;

Expand Down
50 changes: 44 additions & 6 deletions impl/src/fmt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl BoundsAttribute {
/// ```
///
/// [`fmt`]: std::fmt
#[derive(Debug)]
#[derive(Clone, Debug)]
struct FmtAttribute {
/// Interpolation [`syn::LitStr`].
///
Expand Down Expand Up @@ -196,8 +196,8 @@ impl FmtAttribute {
Some((expr, format_ident!("{trait_name}")))
}

/// Returns an [`Iterator`] over bounded [`syn::Type`]s (and correspondent trait names) by this
/// [`FmtAttribute`].
/// Returns an [`Iterator`] over bounded [`syn::Type`]s (and correspondent trait names) of the
/// provided [`syn::Fields`] used by this [`FmtAttribute`].
fn bounded_types<'a>(
&'a self,
fields: &'a syn::Fields,
Expand Down Expand Up @@ -235,6 +235,45 @@ impl FmtAttribute {
})
}

/// Returns an [`Iterator`] over the provided [`syn::Field`]s used by this [`FmtAttribute`],
/// along with the correspondent [`syn::Ident`] it's referred by in this [`FmtAttribute`].
fn iter_used_fields<'a>(
&'a self,
fields: &'a syn::Fields,
) -> impl Iterator<Item = (syn::Ident, &'a syn::Field)> {
let placeholders = Placeholder::parse_fmt_string(&self.lit.value());

// We ignore unknown fields, as compiler will produce better error messages.
placeholders.into_iter().filter_map(move |placeholder| {
let name = match &placeholder.arg {
Parameter::Named(name) => self
.args
.iter()
.find_map(|a| (a.alias()? == &name).then_some(&a.expr))
.map_or(Some(format_ident!("{name}")), |expr| expr.ident().cloned())?,
Parameter::Positional(i) => self
.args
.iter()
.nth(*i)
.and_then(|a| a.expr.ident().filter(|_| a.alias.is_none()))?
.clone(),
};
let position = name.to_string().strip_prefix('_').and_then(|s| s.parse().ok());

let field = match (&fields, position) {
(syn::Fields::Unnamed(f), Some(i)) => {
f.unnamed.iter().nth(i)
}
(syn::Fields::Named(f), None) => f.named.iter().find_map(|f| {
f.ident.as_ref().filter(|s| **s == name).map(|_| f)
}),
_ => None,
}?;

Some((name, field))
})
}

/// Errors in case legacy syntax is encountered: `fmt = "...", (arg),*`.
fn check_legacy_fmt(input: ParseStream<'_>) -> syn::Result<()> {
let fork = input.fork();
Expand Down Expand Up @@ -278,11 +317,10 @@ impl FmtAttribute {
}
}

/// Representation of a [named parameter][1] (`identifier '=' expression`) in
/// in a [`FmtAttribute`].
/// Representation of a [named parameter][1] (`identifier '=' expression`) in a [`FmtAttribute`].
///
/// [1]: https://doc.rust-lang.org/stable/std/fmt/index.html#named-parameters
#[derive(Debug)]
#[derive(Clone, Debug)]
struct FmtArgument {
/// `identifier =` [`Ident`].
///
Expand Down
67 changes: 65 additions & 2 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ mod structs {
assert_eq!(format!("{:03?}", UpperHex), "00B");
assert_eq!(format!("{:07?}", LowerExp), "03.15e0");
assert_eq!(format!("{:07?}", UpperExp), "03.15E0");
assert_eq!(format!("{:018?}", Pointer).len(), 18);
assert_eq!(format!("{:018?}", Pointer), format!("{POINTER:018p}"));
}

mod omitted {
Expand Down Expand Up @@ -246,6 +246,35 @@ mod structs {
"Struct {\n field: 0.0,\n}",
);
}

mod pointer {
#[cfg(not(feature = "std"))]
use alloc::format;

use derive_more::Debug;

#[derive(Debug)]
struct Tuple<'a>(#[debug("{_0:p}.{:p}", self.0)] &'a i32);

#[derive(Debug)]
struct Struct<'a> {
#[debug("{field:p}.{:p}", self.field)]
field: &'a i32,
}

#[test]
fn assert() {
let a = 42;
assert_eq!(
format!("{:?}", Tuple(&a)),
format!("Tuple({0:p}.{0:p})", &a),
);
assert_eq!(
format!("{:?}", Struct { field: &a }),
format!("Struct {{ field: {0:p}.{0:p} }}", &a),
);
}
}
}

mod ignore {
Expand Down Expand Up @@ -527,6 +556,37 @@ mod structs {
assert_eq!(format!("{:?}", Tuple(10, true)), "10 * true");
assert_eq!(format!("{:?}", Struct { a: 10, b: true }), "10 * true");
}

mod pointer {
#[cfg(not(feature = "std"))]
use alloc::format;

use derive_more::Debug;

#[derive(Debug)]
#[debug("{_0:p} * {_1:p}", _0 = self.0)]
struct Tuple<'a, 'b>(&'a u8, &'b bool);

#[derive(Debug)]
#[debug("{a:p} * {b:p}", a = self.a)]
struct Struct<'a, 'b> {
a: &'a u8,
b: &'b bool,
}

#[test]
fn assert() {
let (a, b) = (10, true);
assert_eq!(
format!("{:?}", Tuple(&a, &b)),
format!("{:p} * {:p}", &a, &b),
);
assert_eq!(
format!("{:?}", Struct { a: &a, b: &b }),
format!("{:p} * {:p}", &a, &b),
);
}
}
}

mod ignore {
Expand Down Expand Up @@ -677,7 +737,10 @@ mod enums {
assert_eq!(format!("{:03?}", Unit::UpperHex), "00B");
assert_eq!(format!("{:07?}", Unit::LowerExp), "03.15e0");
assert_eq!(format!("{:07?}", Unit::UpperExp), "03.15E0");
assert_eq!(format!("{:018?}", Unit::Pointer).len(), 18);
assert_eq!(
format!("{:018?}", Unit::Pointer),
format!("{POINTER:018p}"),
);
}

mod omitted {
Expand Down
Loading