Skip to content

Commit 6d4346d

Browse files
committed
feat(rust): test_fn now can expand rstest attr if the test mod has already use rstest
1 parent dd3180d commit 6d4346d

File tree

3 files changed

+108
-6
lines changed

3 files changed

+108
-6
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ tmss! -> absl::flat_hash_map<std::string, std::string>
217217
| `.ts` | Switches indent's coding style between `CamelCase` and `snake_case`. | `indent` |
218218
| `.sc` | Wraps with `static_cast<>(?)`. | `any_expr` |
219219
| `.single` | Wraps with `ranges::views::single(?)`. | `any_expr` |
220-
| `.await` | Expands to `co_await ?`. | `any_expr` |
220+
| `.await` | Expands to `co_await ?`. | `any_expr` |
221221
| `.in` | Expands to `if (...find)` statements. | `any_expr` |
222222

223223
</details>

lua/luasnip-snippets/config.lua

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717
---@class LSSnippets.Config.Snippet.Cpp
1818
---@field quick_type? LSSnippets.Config.Snippet.Cpp.QuickType
1919

20+
---@class LSSnippets.Config.Snippet.Rust
21+
---@field rstest_support? boolean
22+
2023
---@alias LSSnippets.Config.Snippet.DisableSnippets string[]
2124
---@alias LSSnippets.SupportLangs 'cpp'|'dart'|'lua'|'rust'|'nix'|'typescript'|'*'
2225

2326
---@class LSSnippets.Config.Snippet
2427
---@field lua? LSSnippets.Config.Snippet.Lua
2528
---@field cpp? LSSnippets.Config.Snippet.Cpp
29+
---@field rust? LSSnippets.Config.Snippet.Rust
2630

2731
---@class LSSnippets.Config
2832
---@field copyright_header? string
@@ -35,7 +39,7 @@ local config = {}
3539
---@param opts? LSSnippets.Config
3640
local function setup(opts)
3741
opts = opts or {}
38-
config = vim.tbl_extend("force", config, opts)
42+
config = vim.tbl_deep_extend("force", config, opts)
3943
end
4044

4145
---@return any

lua/luasnip-snippets/snippets/rust/test_fn.lua

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
local UtilsTS = require("luasnip-snippets.utils.treesitter")
2+
local Config = require("luasnip-snippets.config")
3+
local UtilsTbl = require("luasnip-snippets.utils.tbl")
24
local ls = require("luasnip")
35
local d = ls.dynamic_node
46
local sn = ls.snippet_node
@@ -8,6 +10,66 @@ local fmta = require("luasnip.extras.fmt").fmta
810
local i = require("luasnip-snippets.nodes").insert_node
911
local c = require("luasnip-snippets.nodes").choice_node
1012

13+
---@param left string[]
14+
---@param right string[]
15+
---@param sep string
16+
---@return string[]
17+
local function dot_concat(left, right, sep)
18+
local ret = {}
19+
20+
for _, l in ipairs(left) do
21+
for _, r in ipairs(right) do
22+
ret[#ret + 1] = l .. sep .. r
23+
end
24+
end
25+
26+
return ret
27+
end
28+
29+
---@param node TSNode?
30+
---@return string[]
31+
local function flat_scoped_use_list(source, node)
32+
if node == nil then
33+
return {}
34+
end
35+
if node:type() ~= "scoped_use_list" then
36+
return {
37+
vim.treesitter.get_node_text(node, source),
38+
}
39+
end
40+
41+
local path_nodes = node:field("path")
42+
if #path_nodes == 0 then
43+
return {}
44+
end
45+
46+
local paths = {}
47+
for _, path_node in ipairs(path_nodes) do
48+
vim.list_extend(paths, flat_scoped_use_list(source, path_node))
49+
end
50+
51+
local items = {}
52+
local name_nodes = node:field("name")
53+
for _, name_node in ipairs(name_nodes) do
54+
vim.list_extend(items, flat_scoped_use_list(source, name_node))
55+
end
56+
local list_nodes = node:field("list")
57+
local allow_list = {
58+
scoped_use_list = 1,
59+
use_wildcard = 1,
60+
identifier = 1,
61+
}
62+
for _, list_node in ipairs(list_nodes) do
63+
for child in list_node:iter_children() do
64+
if allow_list[child:type()] == 1 then
65+
vim.list_extend(items, flat_scoped_use_list(source, child))
66+
end
67+
end
68+
end
69+
70+
return dot_concat(paths, items, "::")
71+
end
72+
1173
local function inject_expanding_environment(_, _, match, captures)
1274
local row, col = unpack(vim.api.nvim_win_get_cursor(0))
1375
local buf = vim.api.nvim_get_current_buf()
@@ -73,6 +135,31 @@ local function inject_expanding_environment(_, _, match, captures)
73135
prev = prev:prev_sibling()
74136
end
75137
ret.env_override["ATTRIBUTES_ITEMS"] = attributes
138+
139+
if Config.get("snippet.rust.rstest_support") == true then
140+
-- check if this mod contains `use rstest::rstest;`
141+
local use_list = {}
142+
for _, body in ipairs(mod_item:field("body")) do
143+
for child in body:iter_children() do
144+
if child:type() == "use_declaration" then
145+
local nodes = child:field("argument")
146+
for _, use_node in ipairs(nodes) do
147+
local node_type = use_node:type()
148+
if node_type == "scoped_use_list" then
149+
vim.list_extend(
150+
use_list,
151+
flat_scoped_use_list(source, use_node)
152+
)
153+
elseif node_type == "scoped_identifier" then
154+
use_list[#use_list + 1] =
155+
vim.treesitter.get_node_text(use_node, source)
156+
end
157+
end
158+
end
159+
end
160+
end
161+
ret.env_override["USE_LIST"] = use_list
162+
end
76163
end
77164

78165
vim.api.nvim_win_set_cursor(0, { row, col })
@@ -102,6 +189,20 @@ return {
102189
end
103190

104191
if in_test_cfg and env.MOD_ITEM_NAME == "tests" then
192+
local test_fn_attrs = {
193+
t("#[test]"),
194+
t("#[tokio::test]"),
195+
}
196+
197+
if Config.get("snippet.rust.rstest_support") == true then
198+
local use_list = env["USE_LIST"] or {}
199+
if vim.list_contains(use_list, "rstest::rstest") then
200+
test_fn_attrs[#test_fn_attrs + 1] = t("#[rstest]")
201+
elseif vim.list_contains(use_list, "rstest::*") then
202+
test_fn_attrs[#test_fn_attrs + 1] = t("#[rstest]")
203+
end
204+
end
205+
105206
-- function item
106207
return sn(
107208
nil,
@@ -121,10 +222,7 @@ return {
121222
end
122223
end, { 2 }),
123224
name = i(1, "new_fn", { desc = "function name" }),
124-
attr = c(2, {
125-
t("#[test]"),
126-
t("#[tokio::test]"),
127-
}, { desc = "function attributes" }),
225+
attr = c(2, test_fn_attrs, { desc = "function attributes" }),
128226
body = i(0),
129227
}
130228
)

0 commit comments

Comments
 (0)