diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index d04bc9d9c6..d96f5713ea 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -147,6 +147,7 @@ macro trace(args...) track_numbers = true checkpointing = false mincut = false + tessera = false expr = first(args) while length(args) > 1 @@ -165,6 +166,9 @@ macro trace(args...) mincut = val end args = args[2:end] + elseif args[1] === :tessera + tessera = true + args = args[2:end] else break end @@ -191,9 +195,10 @@ macro trace(args...) end ) ) - return esc(trace_function_definition(__module__, expr)) + return esc(trace_function_definition(__module__, expr; tessera)) end #! format: on + @assert !tessera "tessera annotation is only allowed in front of function definitions" if Meta.isexpr(expr, :(=)) if Meta.isexpr(expr.args[2], :if) @@ -241,10 +246,12 @@ function get_argname(expr) return expr, expr end -function trace_function_definition(mod, expr) +function trace_function_definition(mod, expr; tessera=false) internal_fn = MacroTools.splitdef(expr) orig_fname = internal_fn[:name] + tessera_name = tessera ? String(orig_fname) : nothing + isfunctor = Meta.isexpr(orig_fname, :(::)) fname = gensym(Symbol(orig_fname, :internal)) internal_fn[:name] = fname @@ -269,12 +276,18 @@ function trace_function_definition(mod, expr) end if isempty(new_fn[:kwargs]) - traced_call_expr = :($(traced_call)($(fname), $(argnames...))) + traced_call_expr = + :($(traced_call)($(fname), $(argnames...); tessera_name=$(tessera_name))) untraced_call_expr = :($(fname)($(argnames...))) else kws = first.(get_argname.(new_fn[:kwargs])) - traced_call_expr = - :($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...))) + traced_call_expr = :($(traced_call)( + Core.kwcall, + (; $(kws...)), + $(fname), + $(argnames...); + tessera_name=$(tessera_name), + )) untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...))) end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index f6d53381df..0dab89181e 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -4,8 +4,8 @@ function ReactantCore.traced_if( return @opcall if_condition(cond, true_fn, false_fn, args...; track_numbers) end -function ReactantCore.traced_call(f::Function, args...) - return @opcall call(f, args...) +function ReactantCore.traced_call(f::Function, args...; tessera_name=nothing) + return @opcall call(f, args...; tessera_name) end function ReactantCore.traced_while( diff --git a/src/Ops.jl b/src/Ops.jl index 75e4c659cb..d36c0d22e3 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -2651,9 +2651,11 @@ end return corrected_traced_results end -@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__)) +@noinline function call( + f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_name=nothing +) seen = Reactant.OrderedIdDict() - cache_key = [] + cache_key = Any[tessera_name] Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes) cache = Reactant.Compiler.callcache() if haskey(cache, cache_key) @@ -2693,6 +2695,9 @@ end resprefix, resargprefix, ) + if !isnothing(tessera_name) + MLIR.IR.attr!(temp.f, "tessera_name", MLIR.IR.Attribute(tessera_name)) + end end seen_cache = Reactant.OrderedIdDict() diff --git a/test/runtests.jl b/test/runtests.jl index bbd5e0855f..3c073a4fee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,7 @@ end @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") @safetestset "QA" include("qa.jl") + @safetestset "Tessera" include("tessera.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" diff --git a/test/tessera.jl b/test/tessera.jl new file mode 100644 index 0000000000..0a58f8ef6f --- /dev/null +++ b/test/tessera.jl @@ -0,0 +1,12 @@ +using Reactant, Test + +@trace tessera function foo(x) + return sin.(sum(x) .+ x) +end + +@testset "Tessera Annotation Tests" begin + x = Reactant.to_rarray(rand(3)) + # if optimize=false is not set, the function is inlined. + hlo = repr(@code_hlo optimize = false foo(x)) + @test occursin("tessera_name = \"foo\"", hlo) +end