@@ -69,8 +69,6 @@ def _calc_corner_svds(
6969 C1_svd , indices_are_sorted = True , unique_indices = True
7070 )
7171
72- # debug_print("C1: {}", C1_svd)
73-
7472 C2_svd = jnp .linalg .svd (t .C2 , full_matrices = False , compute_uv = False )
7573 step_corner_svd = step_corner_svd .at [ti , 1 , : C2_svd .shape [0 ]].set (
7674 C2_svd , indices_are_sorted = True , unique_indices = True
@@ -382,7 +380,7 @@ def _ctmrg_body_func(carry):
382380 config ,
383381 ) = carry
384382
385- if state . ctmrg_split_transfer :
383+ if w_unitcell_last_step . is_split_transfer () :
386384 w_unitcell , norm_smallest_S = do_absorption_step_split_transfer (
387385 w_tensors , w_unitcell_last_step , config , state
388386 )
@@ -397,7 +395,7 @@ def elementwise_func(old, new, old_corner, conv_eps, config):
397395 new ,
398396 conv_eps ,
399397 verbose = config .ctmrg_verbose_output ,
400- split_transfer = state . ctmrg_split_transfer ,
398+ split_transfer = w_unitcell . is_split_transfer () ,
401399 )
402400 return converged , measure , verbose_data , old_corner
403401
@@ -535,10 +533,6 @@ def calc_ctmrg_env(
535533 norm_smallest_S = jnp .nan
536534 already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
537535
538- varipeps_global_state .ctmrg_split_transfer = isinstance (
539- unitcell .get_unique_tensors ()[0 ], PEPS_Tensor_Split_Transfer
540- )
541-
542536 while True :
543537 tmp_count = 0
544538 corner_singular_vals = None
@@ -776,6 +770,7 @@ def _ctmrg_rev_while_body(carry):
776770 bar_fixed_point .get_unique_tensors (),
777771 config .ad_custom_convergence_eps ,
778772 verbose = config .ad_custom_verbose_output ,
773+ split_transfer = bar_fixed_point .is_split_transfer (),
779774 )
780775
781776 count += 1
@@ -796,15 +791,31 @@ def _ctmrg_rev_while_body(carry):
796791
797792@jit
798793def _ctmrg_rev_workhorse (peps_tensors , new_unitcell , new_unitcell_bar , config , state ):
799- _ , vjp_peps_tensors = vjp (
800- lambda t : do_absorption_step (t , new_unitcell , config , state ), peps_tensors
801- )
794+ if new_unitcell .is_split_transfer ():
795+ _ , vjp_peps_tensors = vjp (
796+ lambda t : do_absorption_step_split_transfer (t , new_unitcell , config , state ),
797+ peps_tensors ,
798+ )
802799
803- vjp_env = tree_util .Partial (
804- vjp (lambda u : do_absorption_step (peps_tensors , u , config , state ), new_unitcell )[
805- 1
806- ]
807- )
800+ vjp_env = tree_util .Partial (
801+ vjp (
802+ lambda u : do_absorption_step_split_transfer (
803+ peps_tensors , u , config , state
804+ ),
805+ new_unitcell ,
806+ )[1 ]
807+ )
808+ else :
809+ _ , vjp_peps_tensors = vjp (
810+ lambda t : do_absorption_step (t , new_unitcell , config , state ), peps_tensors
811+ )
812+
813+ vjp_env = tree_util .Partial (
814+ vjp (
815+ lambda u : do_absorption_step (peps_tensors , u , config , state ),
816+ new_unitcell ,
817+ )[1 ]
818+ )
808819
809820 def cond_func (carry ):
810821 _ , _ , _ , converged , count , config , state = carry
0 commit comments