Struggling with objax, Modules, StateVar, and python classes #262
Replies: 3 comments
Weirdly, I can replace the …
Even weirder, I've tracked it down to line 67 in …
With …
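Definitely some caching/closure issue. If I create a module-level function like so:

```python
import objax


def get_train_op(model):
    # A separate Adam optimizer (with its own state) for this model only
    opt_model = objax.optimizer.Adam(model.vars())
    energy = objax.GradValues(model.energy, model.vars())

    def train_op(_s, _t):
        dE, E = energy(_s, _t)  # gradients and value of model.energy
        opt_model(0.1, dE)      # Adam step, learning rate 0.1

    # Pass Jit every variable train_op touches: the model's and the optimizer's
    return objax.Jit(train_op, model.vars() + opt_model.vars())
```

and then use it to create each entry of `self.ops`:

```python
self.ops[i] = get_train_op(self.objs[i])
```

then everything works as expected. I can use the … I literally traced the moment it changes the wrong …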
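The full class-based code isn't shown in this thread, so this is only a guess at the mechanism. Still, the symptoms (fine with `num = 1`, broken with `num = 2`, fixed by a module-level factory like `get_train_op`) match Python's late-binding closures: a function defined inside a loop looks up the loop variable when it is called, not when it is defined, so every op built in the loop ends up pointing at the last object. With a single object that is harmless, because the last index is also the only one. With two, a jitted op may step another object's optimizer whose variables were not passed to `objax.Jit`, letting traced values escape the compiled function, which is the kind of situation `UnexpectedTracerError` reports. The sketch below is plain Python; `make_ops_late_binding`, `make_op`, `make_ops_factory` and the string "models" are hypothetical stand-ins for the objax objects.

```python
# Hypothetical stand-ins: strings play the role of the objax models.


def make_ops_late_binding(models):
    ops = []
    for i in range(len(models)):
        def op():
            # `i` is read when op() runs, after the loop has finished
            return models[i]
        ops.append(op)
    return ops


def make_op(model):
    # Each call creates a fresh `model` binding, like get_train_op(model)
    def op():
        return model
    return op


def make_ops_factory(models):
    return [make_op(m) for m in models]


models = ["model_0", "model_1"]
late = make_ops_late_binding(models)
fixed = make_ops_factory(models)
print([op() for op in late])   # ['model_1', 'model_1']
print([op() for op in fixed])  # ['model_0', 'model_1']
```

Passing the object through a factory argument gives each closure its own variable. A default argument (`def op(i=i): ...`) is the other common way to capture the value at definition time.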
I am having trouble understanding how to use objax Modules with Python classes and manually update `StateVar`s. In the code below, I basically want a container holding `num` `Test` objects, each of which has its own `train_op` and Adam optimizer. It's a silly example and I'm not sure the maths makes sense, but it seems to highlight the issue. If I run the code below with `num = 1`, it JIT-compiles and runs fine. As soon as I set `num = 2`, I get an `UnexpectedTracerError` that points to the `opt_model(0.1, dE)` line. The documentation seems to say that `StateVar`s are used for manually tracking variables; in my case I have a set of manually updated variables and a set of Adam-updated variables, both updated using the same loss function. What am I doing wrong here? (The `functools.partial` seems to break things even with `num = 1`.)