Skip to content

Commit ce84be1

Browse files
committed
fix: some instabilities in base
1 parent 23a8bbe commit ce84be1

File tree

6 files changed

+30
-29
lines changed

6 files changed

+30
-29
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module DynamicExpressions
22

33
using DispatchDoctor: @stable, @unstable
44

5-
@stable default_mode="disable" begin
5+
@stable default_mode = "disable" begin
66
include("Utils.jl")
77
include("ExtensionInterface.jl")
88
include("OperatorEnum.jl")

src/OperatorEnumConstruction.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function Base.show(io::IO, tree::AbstractExpressionNode)
4545
return print(io, string_tree(tree, latest_operators; kwargs...))
4646
end
4747
end
48-
function (tree::AbstractExpressionNode)(X; kws...)
48+
@unstable function (tree::AbstractExpressionNode)(X; kws...)
4949
Base.depwarn(
5050
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
5151
:AbstractExpressionNode,
@@ -64,7 +64,7 @@ function (tree::AbstractExpressionNode)(X; kws...)
6464
end
6565
end
6666

67-
function _grad_evaluator(tree::AbstractExpressionNode, X; kws...)
67+
@unstable function _grad_evaluator(tree::AbstractExpressionNode, X; kws...)
6868
Base.depwarn(
6969
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
7070
:AbstractExpressionNode,
@@ -95,7 +95,7 @@ function set_default_operators!(operators::GenericOperatorEnum)
9595
return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum
9696
end
9797

98-
function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
98+
@unstable function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
9999
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
100100
if !haskey(mapping, f)
101101
error(
@@ -420,7 +420,7 @@ and `(::AbstractExpressionNode)(X)`.
420420
are *not* needed for the package to work; they are purely for convenience.
421421
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
422422
"""
423-
function GenericOperatorEnum(;
423+
@unstable function GenericOperatorEnum(;
424424
binary_operators=Function[],
425425
unary_operators=Function[],
426426
define_helper_functions::Bool=true,

src/Simplify.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,9 @@ end
122122

123123
# Simplify tree
124124
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
125-
tree = tree_mapreduce(
126-
identity,
127-
(p, c...) -> combine_children!(operators, p, c...),
128-
tree,
129-
constructorof(typeof(tree));
125+
return tree_mapreduce(
126+
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
130127
)
131-
return tree
132128
end
133129

134130
end

src/base.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ using ..UtilsModule: @memoize_on, @with_memoize, Undefined
3333
[f_branch::Function,]
3434
op::Function,
3535
tree::AbstractNode,
36+
[result_type::Type=Undefined];
3637
f_on_shared::Function=(result, is_shared) -> result,
3738
break_sharing::Val=Val(false),
3839
)
@@ -78,11 +79,11 @@ function tree_mapreduce(
7879
f::F,
7980
op::G,
8081
tree::AbstractNode,
81-
result_type::Type=Undefined;
82+
result_type::Type{RT}=Undefined;
8283
f_on_shared::H=(result, is_shared) -> result,
8384
break_sharing=Val(false),
84-
) where {F<:Function,G<:Function,H<:Function}
85-
return tree_mapreduce(f, f, op, tree, result_type; f_on_shared, break_sharing)
85+
) where {RT,F<:Function,G<:Function,H<:Function}
86+
return tree_mapreduce(f, f, op, tree, RT; f_on_shared, break_sharing)
8687
end
8788
function tree_mapreduce(
8889
f_leaf::F1,
@@ -341,7 +342,7 @@ function count(
341342
end
342343

343344
"""
344-
sum(f::Function, tree::AbstractNode; init=0, return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
345+
sum(f::Function, tree::AbstractNode; return_type=Undefined, f_on_shared=_default_shared_aggregation, break_sharing::Val=Val(false)) where {F<:Function}
345346
346347
Sum the results of a function over a tree. For graphs with shared nodes
347348
such as `GraphNode`, the function `f_on_shared` is called on the result
@@ -351,15 +352,14 @@ behavior).
351352
function sum(
352353
f::F,
353354
tree::AbstractNode;
354-
init=0,
355355
return_type=Undefined,
356356
f_on_shared=_default_shared_aggregation,
357357
break_sharing::Val=Val(false),
358358
) where {F<:Function}
359359
if preserve_sharing(typeof(tree))
360360
@assert typeof(return_type) !== Undefined "Must specify `return_type` as a keyword argument to `sum` if `preserve_sharing` is true."
361361
end
362-
return tree_mapreduce(f, +, tree, return_type; f_on_shared, break_sharing) + init
362+
return tree_mapreduce(f, +, tree, return_type; f_on_shared, break_sharing)
363363
end
364364
function _default_shared_aggregation(c, is_shared)
365365
return is_shared ? (false * c) : c
@@ -382,15 +382,15 @@ function mapreduce(
382382
f::F,
383383
op::G,
384384
tree::AbstractNode;
385-
return_type=Undefined,
385+
return_type::Type{T}=Undefined,
386386
f_on_shared=(c, is_shared) -> is_shared ? (false * c) : c,
387387
break_sharing::Val=Val(false),
388-
) where {F<:Function,G<:Function}
389-
if preserve_sharing(typeof(tree))
390-
@assert typeof(return_type) !== Undefined "Must specify `return_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true."
388+
) where {T,F<:Function,G<:Function}
389+
if preserve_sharing(typeof(tree)) && break_sharing === Val(false)
390+
@assert T !== Undefined "Must specify `return_type` as a keyword argument to `mapreduce` if `preserve_sharing` is true."
391391
end
392392
return tree_mapreduce(
393-
f, (n...) -> reduce(op, n), tree, return_type; f_on_shared, break_sharing
393+
f, (p, c...) -> reduce(op, (p, c...)), tree, T; f_on_shared, break_sharing
394394
)
395395
end
396396

test/test_deprecations.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using DynamicExpressions
22
using Test
33
using Zygote
44
using Suppressor: @capture_err
5+
using DispatchDoctor: allow_unstable
56

67
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
78
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
@@ -16,10 +17,11 @@ for constructor in (OperatorEnum, GenericOperatorEnum)
1617

1718
constructor == GenericOperatorEnum && continue
1819

19-
VERSION >= v"1.9" &&
20-
@test_logs (:warn, r"The `tree'\(X; kws...\)` syntax is deprecated.*") tree'(
21-
[1.0; 2.0;;]
22-
)
20+
if VERSION >= v"1.9"
21+
@test_logs (:warn, r"The `tree'\(X; kws...\)` syntax is deprecated.*") allow_unstable() do
22+
tree'([1.0; 2.0;;])
23+
end
24+
end
2325
end
2426

2527
if VERSION >= v"1.9"

test/test_initial_errors.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@ if VERSION >= v"1.9"
3333
symbolic_to_node(tree, operators)
3434
)
3535

36-
@test_throws("Please load the Zygote.jl package.", tree'(ones(2, 10)))
36+
@test_throws(
37+
"Please load the Zygote.jl package.", allow_unstable(() -> tree'(ones(2, 10)))
38+
)
3739

3840
@test_throws(
39-
"Please load the Bumper.jl package", tree(ones(2, 10), operators; bumper=Val(true))
41+
"Please load the Bumper.jl package",
42+
allow_unstable(() -> tree(ones(2, 10), operators; bumper=Val(true)))
4043
)
4144

4245
@test_throws(
4346
"Please load the LoopVectorization.jl package",
44-
tree(ones(2, 10), operators; turbo=Val(true))
47+
allow_unstable(() -> tree(ones(2, 10), operators; turbo=Val(true)))
4548
)
4649
end

0 commit comments

Comments
 (0)