Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ macro trace(args...)
track_numbers = true
checkpointing = false
mincut = false
tessera = false

expr = first(args)
while length(args) > 1
Expand All @@ -165,6 +166,9 @@ macro trace(args...)
mincut = val
end
args = args[2:end]
elseif args[1] === :tessera
tessera = true
args = args[2:end]
else
break
end
Expand All @@ -191,9 +195,10 @@ macro trace(args...)
end
)
)
return esc(trace_function_definition(__module__, expr))
return esc(trace_function_definition(__module__, expr; tessera))
end
#! format: on
@assert !tessera "tessera annotation is only allowed in front of function definitions"

if Meta.isexpr(expr, :(=))
if Meta.isexpr(expr.args[2], :if)
Expand Down Expand Up @@ -241,10 +246,12 @@ function get_argname(expr)
return expr, expr
end

function trace_function_definition(mod, expr)
function trace_function_definition(mod, expr; tessera=false)
internal_fn = MacroTools.splitdef(expr)
orig_fname = internal_fn[:name]

tessera_name = tessera ? String(orig_fname) : nothing

isfunctor = Meta.isexpr(orig_fname, :(::))
fname = gensym(Symbol(orig_fname, :internal))
internal_fn[:name] = fname
Expand All @@ -269,12 +276,18 @@ function trace_function_definition(mod, expr)
end

if isempty(new_fn[:kwargs])
traced_call_expr = :($(traced_call)($(fname), $(argnames...)))
traced_call_expr =
:($(traced_call)($(fname), $(argnames...); tessera_name=$(tessera_name)))
untraced_call_expr = :($(fname)($(argnames...)))
else
kws = first.(get_argname.(new_fn[:kwargs]))
traced_call_expr =
:($(traced_call)(Core.kwcall, (; $(kws...)), $(fname), $(argnames...)))
traced_call_expr = :($(traced_call)(
Core.kwcall,
(; $(kws...)),
$(fname),
$(argnames...);
tessera_name=$(tessera_name),
))
untraced_call_expr = :(Core.kwcall((; $(kws...)), $(fname), $(argnames...)))
end

Expand Down
4 changes: 2 additions & 2 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ function ReactantCore.traced_if(
return @opcall if_condition(cond, true_fn, false_fn, args...; track_numbers)
end

function ReactantCore.traced_call(f::Function, args...)
return @opcall call(f, args...)
function ReactantCore.traced_call(f::Function, args...; tessera_name=nothing)
return @opcall call(f, args...; tessera_name)
end

function ReactantCore.traced_while(
Expand Down
9 changes: 7 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2651,9 +2651,11 @@ end
return corrected_traced_results
end

@noinline function call(f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__))
@noinline function call(
f, args...; location=mlir_stacktrace("call", @__FILE__, @__LINE__), tessera_name=nothing
)
seen = Reactant.OrderedIdDict()
cache_key = []
cache_key = Any[tessera_name]
Reactant.make_tracer(seen, (f, args...), cache_key, Reactant.TracedToTypes)
cache = Reactant.Compiler.callcache()
if haskey(cache, cache_key)
Expand Down Expand Up @@ -2693,6 +2695,9 @@ end
resprefix,
resargprefix,
)
if !isnothing(tessera_name)
MLIR.IR.attr!(temp.f, "tessera_name", MLIR.IR.Attribute(tessera_name))
end
end

seen_cache = Reactant.OrderedIdDict()
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ end
@safetestset "Config" include("config.jl")
@safetestset "Batching" include("batching.jl")
@safetestset "QA" include("qa.jl")
@safetestset "Tessera" include("tessera.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
Expand Down
12 changes: 12 additions & 0 deletions test/tessera.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Reactant, Test

@trace tessera function foo(x)
return sin.(sum(x) .+ x)
end

@testset "Tessera Annotation Tests" begin
x = Reactant.to_rarray(rand(3))
# if optimize=false is not set, the function is inlined.
hlo = repr(@code_hlo optimize = false foo(x))
@test occursin("tessera_name = \"foo\"", hlo)
end
Loading