Skip to content

Commit 4cbbf86

Browse files
committed
docs: better explanation of interface
1 parent 9963ef5 commit 4cbbf86

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/Interfaces.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,16 @@ ei_components = (
153153
index_constants = "indexes constants in the expression tree" => _check_index_constants,
154154
has_operators = "checks if the expression has operators" => _check_has_operators,
155155
has_constants = "checks if the expression has constants" => _check_has_constants,
156-
get_constants = "gets constants from the expression tree" => _check_get_constants,
157-
set_constants! = "sets constants in the expression tree" => _check_set_constants!,
156+
get_constants = ("gets constants from the expression tree, returning a tuple of: " *
157+
"(1) a flat vector of the constants, and (2) an reference object that " *
158+
"can be used by `set_constants!` to efficiently set them back") => _check_get_constants,
159+
set_constants! = ("sets constants in the expression tree, given: " *
160+
"(1) a flat vector of constants, (2) the expression, and " *
161+
"(3) the reference object produced by `get_constants`") => _check_set_constants!,
158162
string_tree = "returns a string representation of the expression tree" => _check_string_tree,
159163
default_node_type = "returns the default node type for the expression" => _check_default_node,
160164
constructorof = "gets the constructor function for a type" => _check_constructorof,
161-
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce
162-
# TODO: add extract_gradient(gradient, ex::AbstractExpression)
165+
tree_mapreduce = "applies a function across the tree" => _check_tree_mapreduce,
163166
)
164167
)
165168
ei_description = (
@@ -333,10 +336,14 @@ ni_components = (
333336
count_constants = "counts the number of constants" => _check_count_constants,
334337
filter_map = "applies a filter and map function to the tree" => _check_filter_map,
335338
has_constants = "checks if the tree has constants" => _check_has_constants,
336-
get_constants = "gets constants from the tree" => _check_get_constants,
337-
set_constants! = "sets constants in the tree" => _check_set_constants!,
339+
get_constants = ("gets constants from the tree, returning a tuple of: " *
340+
"(1) a flat vector of the constants, and (2) a reference object that " *
341+
"can be used by `set_constants!` to efficiently set them back") => _check_get_constants,
342+
set_constants! = ("sets constants in the tree, given: " *
343+
"(1) a flat vector of constants, (2) the tree, and " *
344+
"(3) the reference object produced by `get_constants`") => _check_set_constants!,
338345
index_constants = "indexes constants in the tree" => _check_index_constants,
339-
has_operators = "checks if the tree has operators" => _check_has_operators
346+
has_operators = "checks if the tree has operators" => _check_has_operators,
340347
)
341348
)
342349

@@ -373,6 +380,8 @@ ni_description = (
373380

374381
#! format: on
375382

376-
# TODO: Create an interface for evaluation
383+
# TODO: Create an interface for evaluation and `extract_gradient`
384+
# extract_gradient = ("given a Zygote-computed gradient with respect to the tree constants, " *
385+
# "extracts a flat vector in the same order as `get_constants`") => _check_extract_gradient,
377386

378387
end

0 commit comments

Comments
 (0)