|
| 1 | +local ls = require("luasnip") |
| 2 | +local UtilsTS = require("luasnip-snippets.utils.treesitter") |
| 3 | +local d = ls.dynamic_node |
| 4 | +local sn = ls.snippet_node |
| 5 | +local t = ls.text_node |
| 6 | +local f = ls.function_node |
| 7 | +local fmta = require("luasnip.extras.fmt").fmta |
| 8 | +local CppCommons = require("luasnip-snippets.snippets.cpp.commons") |
| 9 | +local i = ls.insert_node |
| 10 | +local c = ls.choice_node |
| 11 | + |
| 12 | +---@class LSSnippets.Cpp.Fn.Env |
| 13 | +---@field CPP_ARGUMENT_START { [1]: number, [2]: number }? |
| 14 | +---@field CPP_FUNCTION_BODY_START { [1]: number, [2]: number }? |
| 15 | +---@field CPP_CLASS_BODY_START { [1]: number, [2]: number }? |
| 16 | +---@field CPP_IN_HEADER_FILE boolean |
| 17 | +---@field CPP_IN_QUALIFIED_FUNCTION boolean |
| 18 | + |
| 19 | +---Returns the start pos of a `TSNode` |
| 20 | +---@param node TSNode? |
| 21 | +---@return { [1]: number, [2]: number }? |
| 22 | +local function start_pos(node) |
| 23 | + if node == nil then |
| 24 | + return nil |
| 25 | + end |
| 26 | + local start_row, start_col, _, _ = vim.treesitter.get_node_range(node) |
| 27 | + return { start_row, start_col } |
| 28 | +end |
| 29 | + |
| 30 | +---Returns if the node's declarator is qualified or not. |
| 31 | +---@param node TSNode? `function_definition` node |
| 32 | +---@return boolean |
| 33 | +local function is_qualified_function(node) |
| 34 | + if node == nil then |
| 35 | + return false |
| 36 | + end |
| 37 | + print(node:type()) |
| 38 | + assert(node:type() == "function_definition") |
| 39 | + local declarators = node:field("declarator") |
| 40 | + if declarators == nil or #declarators == 0 then |
| 41 | + return false |
| 42 | + end |
| 43 | + local declarator = declarators[1] |
| 44 | + print(declarator:type()) |
| 45 | + assert(declarator:type() == "function_declarator") |
| 46 | + declarators = declarator:field("declarator") |
| 47 | + if declarators == nil or #declarators == 0 then |
| 48 | + return false |
| 49 | + end |
| 50 | + declarator = declarators[1] |
| 51 | + print(declarator:type()) |
| 52 | + if declarator:type() == "qualified_identifier" then |
| 53 | + return true |
| 54 | + end |
| 55 | + return false |
| 56 | +end |
| 57 | + |
| 58 | +local function inject_expanding_environment(_, line_to_cursor, match, captures) |
| 59 | + local row, col = unpack(vim.api.nvim_win_get_cursor(0)) |
| 60 | + local buf = vim.api.nvim_get_current_buf() |
| 61 | + |
| 62 | + return UtilsTS.invoke_after_reparse_buffer(buf, match, function(parser, _) |
| 63 | + local pos = { |
| 64 | + row - 1, |
| 65 | + col - #match, |
| 66 | + } |
| 67 | + local node = parser:named_node_for_range { |
| 68 | + pos[1], |
| 69 | + pos[2], |
| 70 | + pos[1], |
| 71 | + pos[2], |
| 72 | + } |
| 73 | + |
| 74 | + local ret = { |
| 75 | + trigger = match, |
| 76 | + capture = captures, |
| 77 | + env_override = { |
| 78 | + CPP_ARGUMENT_START = start_pos(UtilsTS.find_first_parent(node, { |
| 79 | + "argument_list", |
| 80 | + "parameter_list", |
| 81 | + })), |
| 82 | + CPP_FUNCTION_BODY_START = start_pos(UtilsTS.find_first_parent(node, { |
| 83 | + "function_definition", |
| 84 | + "lambda_expression", |
| 85 | + "field_declaration", |
| 86 | + })), |
| 87 | + CPP_CLASS_BODY_START = start_pos(UtilsTS.find_first_parent(node, { |
| 88 | + "struct_specifier", |
| 89 | + "class_specifier", |
| 90 | + })), |
| 91 | + CPP_IN_HEADER_FILE = CppCommons.in_header_file(), |
| 92 | + CPP_IN_QUALIFIED_FUNCTION = is_qualified_function( |
| 93 | + UtilsTS.find_first_parent(node, { |
| 94 | + "function_definition", |
| 95 | + }) |
| 96 | + ), |
| 97 | + }, |
| 98 | + } |
| 99 | + |
| 100 | + vim.api.nvim_win_set_cursor(0, { row, col }) |
| 101 | + return ret |
| 102 | + end) |
| 103 | +end |
| 104 | + |
| 105 | +---@param env LSSnippets.Cpp.Fn.Env |
| 106 | +local function make_lambda_snippet_node(env) |
| 107 | + local captures = t("&") |
| 108 | + if env.CPP_CLASS_BODY_START or env.CPP_IN_QUALIFIED_FUNCTION then |
| 109 | + -- inside a member function |
| 110 | + captures = c(3, { |
| 111 | + t("this, &"), |
| 112 | + t("this"), |
| 113 | + t("&"), |
| 114 | + }) |
| 115 | + end |
| 116 | + |
| 117 | + local fmt_args = { |
| 118 | + captures = captures, |
| 119 | + body = i(0), |
| 120 | + specifier = c(1, { |
| 121 | + t(""), |
| 122 | + t(" mutable"), |
| 123 | + }), |
| 124 | + args = i(2), |
| 125 | + } |
| 126 | + |
| 127 | + return sn( |
| 128 | + nil, |
| 129 | + fmta( |
| 130 | + [[ |
| 131 | + [<captures>](<args>)<specifier> { |
| 132 | + <body> |
| 133 | + } |
| 134 | + ]], |
| 135 | + fmt_args |
| 136 | + ) |
| 137 | + ) |
| 138 | +end |
| 139 | + |
| 140 | +---@param env LSSnippets.Cpp.Fn.Env |
| 141 | +local function make_function_snippet_node(env) |
| 142 | + local fmt_args = { |
| 143 | + body = i(0), |
| 144 | + inline_inline = t(""), |
| 145 | + } |
| 146 | + local storage_specifiers = { |
| 147 | + t(""), |
| 148 | + t("static "), |
| 149 | + } |
| 150 | + if not env.CPP_IN_HEADER_FILE then |
| 151 | + storage_specifiers[#storage_specifiers + 1] = t("inline ") |
| 152 | + storage_specifiers[#storage_specifiers + 1] = t("static inline ") |
| 153 | + else |
| 154 | + fmt_args.inline_inline = t("inline ") |
| 155 | + end |
| 156 | + |
| 157 | + local specifiers = { |
| 158 | + t(""), |
| 159 | + t(" noexcept"), |
| 160 | + } |
| 161 | + if env.CPP_CLASS_BODY_START then |
| 162 | + specifiers[#specifiers + 1] = t(" const") |
| 163 | + specifiers[#specifiers + 1] = t(" const noexcept") |
| 164 | + end |
| 165 | + fmt_args.storage_specifier = |
| 166 | + c(1, storage_specifiers, { desc = "storage specifier" }) |
| 167 | + fmt_args.ret = i(2, "auto", { desc = "return type" }) |
| 168 | + fmt_args.name = i(3, "name", { desc = "function name" }) |
| 169 | + fmt_args.args = i(4, "args", { desc = "function arguments" }) |
| 170 | + fmt_args.specifier = c(5, specifiers, { desc = "specifier" }) |
| 171 | + return sn( |
| 172 | + nil, |
| 173 | + fmta( |
| 174 | + [[ |
| 175 | + <storage_specifier><inline_inline>auto <name>(<args>)<specifier> ->> <ret> { |
| 176 | + <body> |
| 177 | + } |
| 178 | + ]], |
| 179 | + fmt_args |
| 180 | + ) |
| 181 | + ) |
| 182 | +end |
| 183 | + |
| 184 | +return { |
| 185 | + ls.s( |
| 186 | + { |
| 187 | + trig = "fn", |
| 188 | + wordTrig = true, |
| 189 | + name = "(fn) Function-Definition/Lambda", |
| 190 | + resolveExpandParams = inject_expanding_environment, |
| 191 | + }, |
| 192 | + d(1, function(_, parent) |
| 193 | + local env = parent.env |
| 194 | + local last_type, last_type_row, last_type_col |
| 195 | + local keys = { |
| 196 | + "CPP_ARGUMENT_START", |
| 197 | + "CPP_FUNCTION_BODY_START", |
| 198 | + "CPP_CLASS_BODY_START", |
| 199 | + } |
| 200 | + for _, key in ipairs(keys) do |
| 201 | + if env[key] ~= nil then |
| 202 | + if last_type == nil then |
| 203 | + last_type = key |
| 204 | + last_type_row = env[key][1] |
| 205 | + last_type_col = env[key][2] |
| 206 | + else |
| 207 | + if |
| 208 | + last_type_row < env[key][1] |
| 209 | + or (last_type_row == env[key][1] and last_type_col < env[key][2]) |
| 210 | + then |
| 211 | + last_type = key |
| 212 | + last_type_row = env[key][1] |
| 213 | + last_type_col = env[key][2] |
| 214 | + end |
| 215 | + end |
| 216 | + end |
| 217 | + end |
| 218 | + |
| 219 | + if |
| 220 | + last_type == "CPP_ARGUMENT_START" |
| 221 | + or last_type == "CPP_FUNCTION_BODY_START" |
| 222 | + then |
| 223 | + return make_lambda_snippet_node(env) |
| 224 | + else |
| 225 | + return make_function_snippet_node(env) |
| 226 | + end |
| 227 | + end, {}) |
| 228 | + ), |
| 229 | +} |
0 commit comments