11local UtilsTS = require (" luasnip-snippets.utils.treesitter" )
2+ local Config = require (" luasnip-snippets.config" )
3+ local UtilsTbl = require (" luasnip-snippets.utils.tbl" )
24local ls = require (" luasnip" )
35local d = ls .dynamic_node
46local sn = ls .snippet_node
@@ -8,6 +10,66 @@ local fmta = require("luasnip.extras.fmt").fmta
810local i = require (" luasnip-snippets.nodes" ).insert_node
911local c = require (" luasnip-snippets.nodes" ).choice_node
1012
13+ --- @param left string[]
14+ --- @param right string[]
15+ --- @param sep string
16+ --- @return string[]
17+ local function dot_concat (left , right , sep )
18+ local ret = {}
19+
20+ for _ , l in ipairs (left ) do
21+ for _ , r in ipairs (right ) do
22+ ret [# ret + 1 ] = l .. sep .. r
23+ end
24+ end
25+
26+ return ret
27+ end
28+
29+ --- @param node TSNode ?
30+ --- @return string[]
31+ local function flat_scoped_use_list (source , node )
32+ if node == nil then
33+ return {}
34+ end
35+ if node :type () ~= " scoped_use_list" then
36+ return {
37+ vim .treesitter .get_node_text (node , source ),
38+ }
39+ end
40+
41+ local path_nodes = node :field (" path" )
42+ if # path_nodes == 0 then
43+ return {}
44+ end
45+
46+ local paths = {}
47+ for _ , path_node in ipairs (path_nodes ) do
48+ vim .list_extend (paths , flat_scoped_use_list (source , path_node ))
49+ end
50+
51+ local items = {}
52+ local name_nodes = node :field (" name" )
53+ for _ , name_node in ipairs (name_nodes ) do
54+ vim .list_extend (items , flat_scoped_use_list (source , name_node ))
55+ end
56+ local list_nodes = node :field (" list" )
57+ local allow_list = {
58+ scoped_use_list = 1 ,
59+ use_wildcard = 1 ,
60+ identifier = 1 ,
61+ }
62+ for _ , list_node in ipairs (list_nodes ) do
63+ for child in list_node :iter_children () do
64+ if allow_list [child :type ()] == 1 then
65+ vim .list_extend (items , flat_scoped_use_list (source , child ))
66+ end
67+ end
68+ end
69+
70+ return dot_concat (paths , items , " ::" )
71+ end
72+
1173local function inject_expanding_environment (_ , _ , match , captures )
1274 local row , col = unpack (vim .api .nvim_win_get_cursor (0 ))
1375 local buf = vim .api .nvim_get_current_buf ()
@@ -73,6 +135,31 @@ local function inject_expanding_environment(_, _, match, captures)
73135 prev = prev :prev_sibling ()
74136 end
75137 ret .env_override [" ATTRIBUTES_ITEMS" ] = attributes
138+
139+ if Config .get (" snippet.rust.rstest_support" ) == true then
140+ -- check if this mod contains `use rstest::rstest;`
141+ local use_list = {}
142+ for _ , body in ipairs (mod_item :field (" body" )) do
143+ for child in body :iter_children () do
144+ if child :type () == " use_declaration" then
145+ local nodes = child :field (" argument" )
146+ for _ , use_node in ipairs (nodes ) do
147+ local node_type = use_node :type ()
148+ if node_type == " scoped_use_list" then
149+ vim .list_extend (
150+ use_list ,
151+ flat_scoped_use_list (source , use_node )
152+ )
153+ elseif node_type == " scoped_identifier" then
154+ use_list [# use_list + 1 ] =
155+ vim .treesitter .get_node_text (use_node , source )
156+ end
157+ end
158+ end
159+ end
160+ end
161+ ret .env_override [" USE_LIST" ] = use_list
162+ end
76163 end
77164
78165 vim .api .nvim_win_set_cursor (0 , { row , col })
@@ -102,6 +189,20 @@ return {
102189 end
103190
104191 if in_test_cfg and env .MOD_ITEM_NAME == " tests" then
192+ local test_fn_attrs = {
193+ t (" #[test]" ),
194+ t (" #[tokio::test]" ),
195+ }
196+
197+ if Config .get (" snippet.rust.rstest_support" ) == true then
198+ local use_list = env [" USE_LIST" ] or {}
199+ if vim .list_contains (use_list , " rstest::rstest" ) then
200+ test_fn_attrs [# test_fn_attrs + 1 ] = t (" #[rstest]" )
201+ elseif vim .list_contains (use_list , " rstest::*" ) then
202+ test_fn_attrs [# test_fn_attrs + 1 ] = t (" #[rstest]" )
203+ end
204+ end
205+
105206 -- function item
106207 return sn (
107208 nil ,
@@ -121,10 +222,7 @@ return {
121222 end
122223 end , { 2 }),
123224 name = i (1 , " new_fn" , { desc = " function name" }),
124- attr = c (2 , {
125- t (" #[test]" ),
126- t (" #[tokio::test]" ),
127- }, { desc = " function attributes" }),
225+ attr = c (2 , test_fn_attrs , { desc = " function attributes" }),
128226 body = i (0 ),
129227 }
130228 )
0 commit comments