diff --git a/lua/neotest-rust/dap.lua b/lua/neotest-rust/dap.lua index 886ca96..174dce9 100644 --- a/lua/neotest-rust/dap.lua +++ b/lua/neotest-rust/dap.lua @@ -83,15 +83,18 @@ end -- Determine if mod is in .rs or /mod.rs local function construct_mod_path(src_path, mod_name) local match_str = "(.-)[^\\/]-%.?(%w+)%.?[^\\/]*$" - local abs_path, _ = string.match(src_path, match_str) + local abs_path, parent_mod = string.match(src_path, match_str) local mod_file = abs_path .. mod_name .. ".rs" local mod_dir = abs_path .. mod_name .. sep .. "mod.rs" + local child_mod = abs_path .. parent_mod .. sep .. mod_name .. ".rs" if util.file_exists(mod_file) then return mod_file elseif util.file_exists(mod_dir) then return mod_dir + elseif util.file_exists(child_mod) then + return child_mod end return nil diff --git a/tests/dap_spec.lua b/tests/dap_spec.lua index 158f85b..344c3af 100644 --- a/tests/dap_spec.lua +++ b/tests/dap_spec.lua @@ -61,6 +61,13 @@ describe("get_test_binary", function() assert.equal(expected, actual) end) + async.it("returns the test binary for src/parent/child.rs", function() + local expected = main_actual + local actual = dap.get_test_binary(root, root .. "/src/parent/child.rs") + + assert.equal(expected, actual) + end) + async.it("returns the test binary for tests/test_it.rs", function() assert(test_it_actual) local expected = root .. "/target/debug/deps/test_it-" diff --git a/tests/data/simple-package/src/main.rs b/tests/data/simple-package/src/main.rs index e3457dd..bb0a2f4 100644 --- a/tests/data/simple-package/src/main.rs +++ b/tests/data/simple-package/src/main.rs @@ -1,4 +1,5 @@ mod mymod; +mod parent; fn main() { println!("Hello, world!"); diff --git a/tests/data/simple-package/src/parent.rs b/tests/data/simple-package/src/parent.rs new file mode 100644 index 0000000..21d0b0d --- /dev/null +++ b/tests/data/simple-package/src/parent.rs @@ -0,0 +1 @@ +pub mod child; diff --git a/tests/data/simple-package/src/parent/child.rs b/tests/data/simple-package/src/parent/child.rs new file mode 100644 index 0000000..c1e9866 --- /dev/null +++ b/tests/data/simple-package/src/parent/child.rs @@ -0,0 +1,7 @@ +#[cfg(test)] +mod tests { + #[test] + fn math() { + assert_eq!(1 + 1, 2); + } +} diff --git a/tests/init_spec.lua b/tests/init_spec.lua index a6c35a6..73e33fc 100644 --- a/tests/init_spec.lua +++ b/tests/init_spec.lua @@ -26,7 +26,7 @@ describe("discover_positions", function() id = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", name = "main.rs", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 0, 0, 25, 0 }, + range = { 0, 0, 26, 0 }, type = "file", }, { @@ -34,7 +34,7 @@ describe("discover_positions", function() id = "tests", name = "tests", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 7, 0, 24, 1 }, + range = { 8, 0, 25, 1 }, type = "namespace", }, { @@ -42,7 +42,7 @@ describe("discover_positions", function() id = "tests::basic_math", name = "basic_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 9, 4, 11, 5 }, + range = { 10, 4, 12, 5 }, type = "test", }, }, @@ -51,7 +51,7 @@ describe("discover_positions", function() id = "tests::failed_math", name = "failed_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 14, 4, 16, 5 }, + range = { 15, 4, 17, 5 }, type = "test", }, }, @@ -60,7 +60,7 @@ describe("discover_positions", function() id = "tests::nested", name = "nested", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 18, 4, 23, 5 }, + range = { 19, 4, 24, 5 }, type = "namespace", }, { @@ -68,7 +68,7 @@ describe("discover_positions", function() id = "tests::nested::nested_math", name = "nested_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 20, 8, 22, 9 }, + range = { 21, 8, 23, 9 }, type = "test", }, },