Skip to content

Commit 429a7cd

Browse files
committed
rename
1 parent 0172a44 commit 429a7cd

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using ..Reactant: TracedRArray
22

3-
function hmc(
3+
function mcmc(
44
rng::AbstractRNG,
55
original_trace,
66
f::Function,
77
args::Vararg{Any,Nargs};
88
selection::Selection,
9+
algorithm::Symbol=:HMC,
910
inverse_mass_matrix=nothing,
1011
step_size=nothing,
1112
num_steps=nothing,
@@ -48,9 +49,19 @@ function hmc(
4849
)::MLIR.IR.Type
4950
accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool))
5051

52+
# Map algorithm symbol to integer enum value
53+
# From EnzymeOps.td: HMC = 0, NUTS = 1
54+
alg_value = if algorithm == :HMC
55+
Int32(0)
56+
elseif algorithm == :NUTS
57+
Int32(1)
58+
else
59+
error("Unknown MCMC algorithm: $algorithm. Supported algorithms are :HMC and :NUTS")
60+
end
61+
5162
alg_attr = @ccall MLIR.API.mlir_c.enzymeMCMCAlgorithmAttrGet(
5263
MLIR.IR.context()::MLIR.API.MlirContext,
53-
0::Int32, # 0 = HMC
64+
alg_value::Int32,
5465
)::MLIR.IR.Attribute
5566

5667
inverse_mass_matrix_val = nothing

src/probprog/ProbProg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ include("FFI.jl")
99
include("Modeling.jl")
1010
include("Display.jl")
1111
include("MH.jl")
12-
include("HMC.jl")
12+
include("MCMC.jl")
1313

1414
# Types.
1515
export ProbProgTrace, Constraint, Selection, Address
@@ -18,7 +18,7 @@ export ProbProgTrace, Constraint, Selection, Address
1818
export get_choices, select
1919

2020
# Core MLIR ops.
21-
export sample, untraced_call, simulate, generate, mh, hmc
21+
export sample, untraced_call, simulate, generate, mh, mcmc
2222

2323
# Gen-like helper functions.
2424
export simulate_, generate_

0 commit comments

Comments
 (0)