From e8989a305cd120b180b5cb9b1a488bdb42b7bf26 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:16:06 -0600 Subject: [PATCH 1/4] add `tessera` option to `@trace` ``` @trace tessera mysin(x)=sin(x) ``` This will add an attribute with name `tessera_name` and value a stringattr with the function name (`"mysin"`) to the generated function for mysin. (note that you need `@code_hlo optimize=false` to see this attribute, otherwise it is inlined.) --- lib/ReactantCore/src/ReactantCore.jl | 18 ++++++++++++++---- src/ControlFlow.jl | 4 ++-- src/Ops.jl | 7 +++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index d04bc9d9c6..c529ec04c3 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -147,6 +147,7 @@ macro trace(args...) track_numbers = true checkpointing = false mincut = false + tessera = false expr = first(args) while length(args) > 1 @@ -165,6 +166,9 @@ macro trace(args...) mincut = val end args = args[2:end] + elseif args[1] === :tessera + tessera = true + args = args[2:end] else break end @@ -191,9 +195,10 @@ macro trace(args...) end ) ) - return esc(trace_function_definition(__module__, expr)) + return esc(trace_function_definition(__module__, expr; tessera)) end #! format: on + @assert !tessera "tessera annotation is only allowed in front of function definitions" if Meta.isexpr(expr, :(=)) if Meta.isexpr(expr.args[2], :if) @@ -241,9 +246,11 @@ function get_argname(expr) return expr, expr end -function trace_function_definition(mod, expr) +function trace_function_definition(mod, expr; tessera=false) internal_fn = MacroTools.splitdef(expr) orig_fname = internal_fn[:name] + + tessera_name = tessera ? orig_fname : nothing isfunctor = Meta.isexpr(orig_fname, :(::)) fname = gensym(Symbol(orig_fname, :internal)) @@ -269,12 +276,12 @@ function trace_function_definition(mod, expr) end if isempty(new_fn[:kwargs]) - traced_call_expr = :($(traced_call)($(fname), $(argnames...))) + traced_call_expr = :($(traced_call)($(fname), $(argnames...); tessera_name=$(String(tessera_name)))) untraced_call_expr = :($(fname)($(argnames...))) else kws = first.(get_argname.(new_fn[:kwargs])) traced_call_expr = - :($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...))) + :($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...); tessera_name=$(String(tessera_name)))) untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...))) end @@ -290,6 +297,9 @@ function trace_function_definition(mod, expr) return quote $(MacroTools.combinedef(new_fn)) $(MacroTools.combinedef(internal_fn)) + + # return the user-facing function: + $(new_fn[:name]) end end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index f6d53381df..0dab89181e 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -4,8 +4,8 @@ function ReactantCore.traced_if( return @opcall if_condition(cond, true_fn, false_fn, args...; track_numbers) end -function ReactantCore.traced_call(f::Function, args...) - return @opcall call(f, args...) +function ReactantCore.traced_call(f::Function, args...; tessera_name=nothing) + return @opcall call(f, args...; tessera_name) end function ReactantCore.traced_while( diff --git a/src/Ops.jl b/src/Ops.jl index 75e4c659cb..be96b87045 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2651,9 +2651,9 @@ end return corrected_traced_results end -@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__)) +@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_name=nothing) seen = Reactant.OrderedIdDict() - cache_key = [] + cache_key = Any[tessera_name] Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) cache = Reactant.Compiler.callcache() if haskey(cache, cache_key) @@ -2693,6 +2693,9 @@ end resprefix, resargprefix, ) + if !isnothing(tessera_name) + MLIR.IR.attr!(temp.f, "tessera_name", MLIR.IR.Attribute(tessera_name)) + end end seen_cache = Reactant.OrderedIdDict() From 59c67e6d07c5f12c350c05b770ccfad1519c6c88 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:24:54 -0600 Subject: [PATCH 2/4] test --- test/runtests.jl | 1 + test/tessera.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 test/tessera.jl diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..3c073a4fee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,7 @@ end @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") @safetestset "QA" include("qa.jl") + @safetestset "Tessera" include("tessera.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/tessera.jl b/test/tessera.jl new file mode 100644 index 0000000000..0a58f8ef6f --- /dev/null +++ b/test/tessera.jl @@ -0,0 +1,12 @@ +using Reactant, Test + +@trace tessera function foo(x) + return sin.(sum(x) .+ x) +end + +@testset "Tessera Annotation Tests" begin + x = Reactant.to_rarray(rand(3)) + # if optimize=false is not set, the function is inlined. + hlo = repr(@code_hlo optimize = false foo(x)) + @test occursin("tessera_name = \"foo\"", hlo) +end From e389362718910cbbdf155d2f82276096e499f282 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:37:54 -0600 Subject: [PATCH 3/4] formatting --- lib/ReactantCore/src/ReactantCore.jl | 19 +++++++++++++------ src/Ops.jl | 6 ++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index c529ec04c3..d5ac1cd591 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -167,8 +167,8 @@ macro trace(args...) end args = args[2:end] elseif args[1] === :tessera - tessera = true - args = args[2:end] + tessera = true + args = args[2:end] else break end @@ -249,7 +249,7 @@ end function trace_function_definition(mod, expr; tessera=false) internal_fn = MacroTools.splitdef(expr) orig_fname = internal_fn[:name] - + tessera_name = tessera ? orig_fname : nothing isfunctor = Meta.isexpr(orig_fname, :(::)) @@ -276,12 +276,19 @@ function trace_function_definition(mod, expr; tessera=false) end if isempty(new_fn[:kwargs]) - traced_call_expr = :($(traced_call)($(fname), $(argnames...); tessera_name=$(String(tessera_name)))) + traced_call_expr = :($(traced_call)( + $(fname), $(argnames...); tessera_name=$(String(tessera_name)) + )) untraced_call_expr = :($(fname)($(argnames...))) else kws = first.(get_argname.(new_fn[:kwargs])) - traced_call_expr = - :($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...); tessera_name=$(String(tessera_name)))) + traced_call_expr = :($(traced_call)( + Core.kwcall, + (; $(kws...)), + $(fname), + $(argnames...); + tessera_name=$(String(tessera_name)), + )) untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...))) end diff --git a/src/Ops.jl b/src/Ops.jl index be96b87045..d36c0d22e3 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2651,7 +2651,9 @@ end return corrected_traced_results end -@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_name=nothing) +@noinline function call( + f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_name=nothing +) seen = Reactant.OrderedIdDict() cache_key = Any[tessera_name] Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) @@ -2694,7 +2696,7 @@ end resargprefix, ) if !isnothing(tessera_name) - MLIR.IR.attr!(temp.f, "tessera_name", MLIR.IR.Attribute(tessera_name)) + MLIR.IR.attr!(temp.f, "tessera_name", MLIR.IR.Attribute(tessera_name)) end end From 6804c974033d66691fd0b586b75b130dbba12d32 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 16 Dec 2025 16:39:59 -0600 Subject: [PATCH 4/4] fix --- lib/ReactantCore/src/ReactantCore.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index d5ac1cd591..d96f5713ea 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -250,7 +250,7 @@ function trace_function_definition(mod, expr; tessera=false) internal_fn = MacroTools.splitdef(expr) orig_fname = internal_fn[:name] - tessera_name = tessera ? orig_fname : nothing + tessera_name = tessera ? String(orig_fname) : nothing isfunctor = Meta.isexpr(orig_fname, :(::)) fname = gensym(Symbol(orig_fname, :internal)) @@ -276,9 +276,8 @@ function trace_function_definition(mod, expr; tessera=false) end if isempty(new_fn[:kwargs]) - traced_call_expr = :($(traced_call)( - $(fname), $(argnames...); tessera_name=$(String(tessera_name)) - )) + traced_call_expr = + :($(traced_call)($(fname), $(argnames...); tessera_name=$(tessera_name))) untraced_call_expr = :($(fname)($(argnames...))) else kws = first.(get_argname.(new_fn[:kwargs])) @@ -287,7 +286,7 @@ function trace_function_definition(mod, expr; tessera=false) (; $(kws...)), $(fname), $(argnames...); - tessera_name=$(String(tessera_name)), + tessera_name=$(tessera_name), )) untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...))) end @@ -304,9 +303,6 @@ function trace_function_definition(mod, expr; tessera=false) return quote $(MacroTools.combinedef(new_fn)) $(MacroTools.combinedef(internal_fn)) - - # return the user-facing function: - $(new_fn[:name]) end end