On Wed, Jan 7, 2026 at 11:31 AM Gary Guo <[email protected]> wrote: > > From: Gary Guo <[email protected]> > > Make use of `syn` to parse the module structurally and thus improve the > robustness of parsing. > > String interpolation is avoided by generating tokens directly using > `quote!`. > > Signed-off-by: Gary Guo <[email protected]>
Reviewed-by: Tamir Duberstein <[email protected]> > --- > rust/macros/kunit.rs | 274 +++++++++++++++++++------------------------ > rust/macros/lib.rs | 6 +- > 2 files changed, 123 insertions(+), 157 deletions(-) > > diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs > index 5cd6aa5eef07d..afbc708cbdc50 100644 > --- a/rust/macros/kunit.rs > +++ b/rust/macros/kunit.rs > @@ -4,81 +4,50 @@ > //! > //! Copyright (c) 2023 José Expósito <[email protected]> > > -use std::collections::HashMap; > -use std::fmt::Write; > - > -use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; > - > -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream > { > - let attr = attr.to_string(); > - > - if attr.is_empty() { > - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") > - } > - > - if attr.len() > 255 { > - panic!("The test suite name `{attr}` exceeds the maximum length of > 255 bytes") > +use std::ffi::CString; > + > +use proc_macro2::TokenStream; > +use quote::{ > + format_ident, > + quote, > + ToTokens, // > +}; > +use syn::{ > + parse_quote, > + Error, > + Ident, > + Item, > + ItemMod, > + LitCStr, > + Result, // > +}; > + > +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> > Result<TokenStream> { > + if test_suite.to_string().len() > 255 { > + return Err(Error::new_spanned( > + test_suite, > + "test suite names cannot exceed the maximum length of 255 bytes", > + )); > } > > - let mut tokens: Vec<_> = ts.into_iter().collect(); > - > - // Scan for the `mod` keyword. > - tokens > - .iter() > - .find_map(|token| match token { > - TokenTree::Ident(ident) => match ident.to_string().as_str() { > - "mod" => Some(true), > - _ => None, > - }, > - _ => None, > - }) > - .expect("`#[kunit_tests(test_name)]` attribute should only be > applied to modules"); > - > - // Retrieve the main body. The main body should be the last token tree. > - let body = match tokens.pop() { > - Some(TokenTree::Group(group)) if group.delimiter() == > Delimiter::Brace => group, > - _ => panic!("Cannot locate main body of module"), > + // We cannot handle modules that defer to another file (e.g. `mod foo;`). > + let Some((module_brace, module_items)) = module.content.take() else { > + Err(Error::new_spanned( > + module, > + "`#[kunit_tests(test_name)]` attribute should only be applied to > inline modules", > + ))? > }; > > - // Get the functions set as tests. Search for `[test]` -> `fn`. > - let mut body_it = body.stream().into_iter(); > - let mut tests = Vec::new(); > - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); > - while let Some(token) = body_it.next() { > - match token { > - TokenTree::Punct(ref p) if p.as_char() == '#' => match > body_it.next() { > - Some(TokenTree::Group(g)) if g.delimiter() == > Delimiter::Bracket => { > - if let Some(TokenTree::Ident(name)) = > g.stream().into_iter().next() { > - // Collect attributes because we need to find which > are tests. We also > - // need to copy `cfg` attributes so tests can be > conditionally enabled. > - attributes > - .entry(name.to_string()) > - .or_default() > - .extend([token, TokenTree::Group(g)]); > - } > - continue; > - } > - _ => (), > - }, > - TokenTree::Ident(i) if i == "fn" && > attributes.contains_key("test") => { > - if let Some(TokenTree::Ident(test_name)) = body_it.next() { > - tests.push((test_name, > attributes.remove("cfg").unwrap_or_default())) > - } > - } > - > - _ => (), > - } > - attributes.clear(); > - } > + // Make the entire module gated behind `CONFIG_KUNIT`. > + module > + .attrs > + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); > > - // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. > - let config_kunit = > "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); > - tokens.insert( > - 0, > - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), > - ); > + let mut processed_items = Vec::new(); > + let mut test_cases = Vec::new(); > > // Generate the test KUnit test suite and a test case for each `#[test]`. > + // > // The code generated for the following test module: > // > // ``` > @@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: > TokenStream) -> TokenStream { > // > // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); > // ``` > - let mut kunit_macros = "".to_owned(); > - let mut test_cases = "".to_owned(); > - let mut assert_macros = "".to_owned(); > - let path = crate::helpers::file(); > - let num_tests = tests.len(); > - for (test, cfg_attr) in tests { > - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); > - // Append any `cfg` attributes the user might have written on their > tests so we don't > - // attempt to call them when they are `cfg`'d out. An extra `use` is > used here to reduce > - // the length of the assert message. > - let kunit_wrapper = format!( > - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut > ::kernel::bindings::kunit) > - {{ > - (*_test).status = > ::kernel::bindings::kunit_status_KUNIT_SKIPPED; > - {cfg_attr} {{ > - (*_test).status = > ::kernel::bindings::kunit_status_KUNIT_SUCCESS; > - use ::kernel::kunit::is_test_result_ok; > - assert!(is_test_result_ok({test}())); > + // > + // Non-function items (e.g. imports) are preserved. > + for item in module_items { > + let Item::Fn(mut f) = item else { > + processed_items.push(item); > + continue; > + }; > + > + // TODO: Replace below with `extract_if` when MSRV is bumped above > 1.85. > + let before_len = f.attrs.len(); > + f.attrs.retain(|attr| !attr.path().is_ident("test")); > + if f.attrs.len() == before_len { > + processed_items.push(Item::Fn(f)); > + continue; > + } > + > + let test = f.sig.ident.clone(); > + > + // Retrieve `#[cfg]` applied on the function which needs to be > present on derived items too. > + let cfg_attrs: Vec<_> = f > + .attrs > + .iter() > + .filter(|attr| attr.path().is_ident("cfg")) > + .cloned() > + .collect(); > + > + // Before the test, override usual `assert!` and `assert_eq!` macros > with ones that call > + // KUnit instead. > + let test_str = test.to_string(); > + let path = crate::helpers::file(); > + processed_items.push(parse_quote! { > + #[allow(unused)] > + macro_rules! assert { > + ($cond:expr $(,)?) => {{ > + kernel::kunit_assert!(#test_str, #path, 0, $cond); > + }} > + } > + }); > + processed_items.push(parse_quote! { > + #[allow(unused)] > + macro_rules! assert_eq { > + ($left:expr, $right:expr $(,)?) => {{ > + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, > $right); > }} > - }}"#, > + } > + }); > + > + // Add back the test item. > + processed_items.push(Item::Fn(f)); > + > + let kunit_wrapper_fn_name = > format_ident!("kunit_rust_wrapper_{test}"); > + let test_cstr = LitCStr::new( > + &CString::new(test_str.as_str()).expect("identifier cannot > contain NUL"), > + test.span(), > ); > - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); > - writeln!( > - test_cases, > - " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), > {kunit_wrapper_fn_name})," > - ) > - .unwrap(); > - writeln!( > - assert_macros, > - r#" > -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. > -#[allow(unused)] > -macro_rules! assert {{ > - ($cond:expr $(,)?) => {{{{ > - kernel::kunit_assert!("{test}", "{path}", 0, $cond); > - }}}} > -}} > - > -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit > instead. > -#[allow(unused)] > -macro_rules! assert_eq {{ > - ($left:expr, $right:expr $(,)?) => {{{{ > - kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); > - }}}} > -}} > - "# > - ) > - .unwrap(); > - } > + processed_items.push(parse_quote! { > + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut > ::kernel::bindings::kunit) { > + (*_test).status = > ::kernel::bindings::kunit_status_KUNIT_SKIPPED; > > - writeln!(kunit_macros).unwrap(); > - writeln!( > - kunit_macros, > - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = > [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", > - num_tests + 1 > - ) > - .unwrap(); > - > - writeln!( > - kunit_macros, > - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" > - ) > - .unwrap(); > - > - // Remove the `#[test]` macros. > - // We do this at a token level, in order to preserve span information. > - let mut new_body = vec![]; > - let mut body_it = body.stream().into_iter(); > - > - while let Some(token) = body_it.next() { > - match token { > - TokenTree::Punct(ref c) if c.as_char() == '#' => match > body_it.next() { > - Some(TokenTree::Group(group)) if group.to_string() == > "[test]" => (), > - Some(next) => { > - new_body.extend([token, next]); > - } > - _ => { > - new_body.push(token); > + // Append any `cfg` attributes the user might have written > on their tests so we > + // don't attempt to call them when they are `cfg`'d out. An > extra `use` is used > + // here to reduce the length of the assert message. > + #(#cfg_attrs)* > + { > + (*_test).status = > ::kernel::bindings::kunit_status_KUNIT_SUCCESS; > + use ::kernel::kunit::is_test_result_ok; > + assert!(is_test_result_ok(#test())); > } > - }, > - _ => { > - new_body.push(token); > } > - } > - } > + }); > > - let mut final_body = TokenStream::new(); > - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); > - final_body.extend(new_body); > - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); > - > - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); > + test_cases.push(quote!( > + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) > + )); > + } > > - tokens.into_iter().collect() > + let num_tests_plus_1 = test_cases.len() + 1; > + processed_items.push(parse_quote! { > + static mut TEST_CASES: [::kernel::bindings::kunit_case; > #num_tests_plus_1] = [ > + #(#test_cases,)* > + ::kernel::kunit::kunit_case_null(), > + ]; > + }); > + processed_items.push(parse_quote! { > + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); > + }); > + > + module.content = Some((module_brace, processed_items)); > + Ok(module.to_token_stream()) > } > diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs > index 12467bfc703a8..75ac60abe6ffa 100644 > --- a/rust/macros/lib.rs > +++ b/rust/macros/lib.rs > @@ -481,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream { > /// } > /// ``` > #[proc_macro_attribute] > -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { > - kunit::kunit_tests(attr.into(), ts.into()).into() > +pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { > + kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input)) > + .unwrap_or_else(|e| e.into_compile_error()) > + .into() > } > -- > 2.51.2 >

