@@ -153,13 +153,16 @@ ei_components = (
153153 index_constants = " indexes constants in the expression tree" => _check_index_constants,
154154 has_operators = " checks if the expression has operators" => _check_has_operators,
155155 has_constants = " checks if the expression has constants" => _check_has_constants,
156- get_constants = " gets constants from the expression tree" => _check_get_constants,
157- set_constants! = " sets constants in the expression tree" => _check_set_constants!,
156+ get_constants = (" gets constants from the expression tree, returning a tuple of: " *
157+ " (1) a flat vector of the constants, and (2) an reference object that " *
158+ " can be used by `set_constants!` to efficiently set them back" ) => _check_get_constants,
159+ set_constants! = (" sets constants in the expression tree, given: " *
160+ " (1) a flat vector of constants, (2) the expression, and " *
161+ " (3) the reference object produced by `get_constants`" ) => _check_set_constants!,
158162 string_tree = " returns a string representation of the expression tree" => _check_string_tree,
159163 default_node_type = " returns the default node type for the expression" => _check_default_node,
160164 constructorof = " gets the constructor function for a type" => _check_constructorof,
161- tree_mapreduce = " applies a function across the tree" => _check_tree_mapreduce
162- # TODO : add extract_gradient(gradient, ex::AbstractExpression)
165+ tree_mapreduce = " applies a function across the tree" => _check_tree_mapreduce,
163166 )
164167)
165168ei_description = (
@@ -333,10 +336,14 @@ ni_components = (
333336 count_constants = " counts the number of constants" => _check_count_constants,
334337 filter_map = " applies a filter and map function to the tree" => _check_filter_map,
335338 has_constants = " checks if the tree has constants" => _check_has_constants,
336- get_constants = " gets constants from the tree" => _check_get_constants,
337- set_constants! = " sets constants in the tree" => _check_set_constants!,
339+ get_constants = (" gets constants from the tree, returning a tuple of: " *
340+ " (1) a flat vector of the constants, and (2) a reference object that " *
341+ " can be used by `set_constants!` to efficiently set them back" ) => _check_get_constants,
342+ set_constants! = (" sets constants in the tree, given: " *
343+ " (1) a flat vector of constants, (2) the tree, and " *
344+ " (3) the reference object produced by `get_constants`" ) => _check_set_constants!,
338345 index_constants = " indexes constants in the tree" => _check_index_constants,
339- has_operators = " checks if the tree has operators" => _check_has_operators
346+ has_operators = " checks if the tree has operators" => _check_has_operators,
340347 )
341348)
342349
@@ -373,6 +380,8 @@ ni_description = (
373380
374381# ! format: on
375382
376- # TODO : Create an interface for evaluation
383+ # TODO : Create an interface for evaluation and `extract_gradient`
384+ # extract_gradient = ("given a Zygote-computed gradient with respect to the tree constants, " *
385+ # "extracts a flat vector in the same order as `get_constants`") => _check_extract_gradient,
377386
378387end
0 commit comments