88import jax .debug as jdebug
99
1010from varipeps import varipeps_config , varipeps_global_state
11- from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
11+ from varipeps .peps import PEPS_Tensor , PEPS_Tensor_Split_Transfer , PEPS_Unit_Cell
1212from varipeps .utils .debug_print import debug_print
13- from .absorption import do_absorption_step
13+ from .absorption import do_absorption_step , do_absorption_step_split_transfer
1414
1515from typing import Sequence , Tuple , List , Optional
1616
@@ -25,6 +25,14 @@ class CTM_Enum(enum.IntEnum):
2525 T2 = enum .auto ()
2626 T3 = enum .auto ()
2727 T4 = enum .auto ()
28+ T1_ket = enum .auto ()
29+ T1_bra = enum .auto ()
30+ T2_ket = enum .auto ()
31+ T2_bra = enum .auto ()
32+ T3_ket = enum .auto ()
33+ T3_bra = enum .auto ()
34+ T4_ket = enum .auto ()
35+ T4_bra = enum .auto ()
2836
2937
3038class CTMRGNotConvergedError (Exception ):
@@ -61,6 +69,8 @@ def _calc_corner_svds(
6169 C1_svd , indices_are_sorted = True , unique_indices = True
6270 )
6371
72+ # debug_print("C1: {}", C1_svd)
73+
6474 C2_svd = jnp .linalg .svd (t .C2 , full_matrices = False , compute_uv = False )
6575 step_corner_svd = step_corner_svd .at [ti , 1 , : C2_svd .shape [0 ]].set (
6676 C2_svd , indices_are_sorted = True , unique_indices = True
@@ -79,15 +89,20 @@ def _calc_corner_svds(
7989 return step_corner_svd
8090
8191
82- @partial (jit , static_argnums = (3 ,), inline = True )
92+ @partial (jit , static_argnums = (3 , 4 ), inline = True )
8393def _is_element_wise_converged (
8494 old_peps_tensors : List [PEPS_Tensor ],
8595 new_peps_tensors : List [PEPS_Tensor ],
8696 eps : float ,
8797 verbose : bool = False ,
98+ split_transfer : bool = False ,
8899) -> Tuple [bool , float , Optional [List [Tuple [int , CTM_Enum , float ]]]]:
89100 result = 0
90- measure = jnp .zeros ((len (old_peps_tensors ), 8 ), dtype = jnp .float64 )
101+
102+ if split_transfer :
103+ measure = jnp .zeros ((len (old_peps_tensors ), 12 ), dtype = jnp .float64 )
104+ else :
105+ measure = jnp .zeros ((len (old_peps_tensors ), 8 ), dtype = jnp .float64 )
91106
92107 verbose_data = [] if verbose else None
93108
@@ -144,73 +159,210 @@ def _is_element_wise_converged(
144159 if verbose :
145160 verbose_data .append ((ti , CTM_Enum .C4 , jnp .amax (diff )))
146161
147- old_shape = old_peps_tensors [ti ].T1 .shape
148- new_shape = new_peps_tensors [ti ].T1 .shape
149- diff = jnp .abs (
150- new_peps_tensors [ti ].T1 [
151- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
152- ]
153- - old_peps_tensors [ti ].T1 [
154- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
155- ]
156- )
157- result += jnp .sum (diff > eps )
158- measure = measure .at [ti , 4 ].set (
159- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
160- )
161- if verbose :
162- verbose_data .append ((ti , CTM_Enum .T1 , jnp .amax (diff )))
163-
164- old_shape = old_peps_tensors [ti ].T2 .shape
165- new_shape = new_peps_tensors [ti ].T2 .shape
166- diff = jnp .abs (
167- new_peps_tensors [ti ].T2 [
168- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
169- ]
170- - old_peps_tensors [ti ].T2 [
171- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
172- ]
173- )
174- result += jnp .sum (diff > eps )
175- measure = measure .at [ti , 5 ].set (
176- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
177- )
178- if verbose :
179- verbose_data .append ((ti , CTM_Enum .T2 , jnp .amax (diff )))
180-
181- old_shape = old_peps_tensors [ti ].T3 .shape
182- new_shape = new_peps_tensors [ti ].T3 .shape
183- diff = jnp .abs (
184- new_peps_tensors [ti ].T3 [
185- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
186- ]
187- - old_peps_tensors [ti ].T3 [
188- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
189- ]
190- )
191- result += jnp .sum (diff > eps )
192- measure = measure .at [ti , 6 ].set (
193- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
194- )
195- if verbose :
196- verbose_data .append ((ti , CTM_Enum .T3 , jnp .amax (diff )))
197-
198- old_shape = old_peps_tensors [ti ].T4 .shape
199- new_shape = new_peps_tensors [ti ].T4 .shape
200- diff = jnp .abs (
201- new_peps_tensors [ti ].T4 [
202- : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
203- ]
204- - old_peps_tensors [ti ].T4 [
205- : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
206- ]
207- )
208- result += jnp .sum (diff > eps )
209- measure = measure .at [ti , 7 ].set (
210- jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
211- )
212- if verbose :
213- verbose_data .append ((ti , CTM_Enum .T4 , jnp .amax (diff )))
162+ if split_transfer :
163+ old_shape = old_peps_tensors [ti ].T1_ket .shape
164+ new_shape = new_peps_tensors [ti ].T1_ket .shape
165+ diff = jnp .abs (
166+ new_peps_tensors [ti ].T1_ket [
167+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
168+ ]
169+ - old_peps_tensors [ti ].T1_ket [
170+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
171+ ]
172+ )
173+ result += jnp .sum (diff > eps )
174+ measure = measure .at [ti , 4 ].set (
175+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
176+ )
177+ if verbose :
178+ verbose_data .append ((ti , CTM_Enum .T1_ket , jnp .amax (diff )))
179+
180+ old_shape = old_peps_tensors [ti ].T1_bra .shape
181+ new_shape = new_peps_tensors [ti ].T1_bra .shape
182+ diff = jnp .abs (
183+ new_peps_tensors [ti ].T1_bra [
184+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
185+ ]
186+ - old_peps_tensors [ti ].T1_bra [
187+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
188+ ]
189+ )
190+ result += jnp .sum (diff > eps )
191+ measure = measure .at [ti , 5 ].set (
192+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
193+ )
194+ if verbose :
195+ verbose_data .append ((ti , CTM_Enum .T1_bra , jnp .amax (diff )))
196+
197+ old_shape = old_peps_tensors [ti ].T2_ket .shape
198+ new_shape = new_peps_tensors [ti ].T2_ket .shape
199+ diff = jnp .abs (
200+ new_peps_tensors [ti ].T2_ket [
201+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
202+ ]
203+ - old_peps_tensors [ti ].T2_ket [
204+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
205+ ]
206+ )
207+ result += jnp .sum (diff > eps )
208+ measure = measure .at [ti , 6 ].set (
209+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
210+ )
211+ if verbose :
212+ verbose_data .append ((ti , CTM_Enum .T2_ket , jnp .amax (diff )))
213+
214+ old_shape = old_peps_tensors [ti ].T2_bra .shape
215+ new_shape = new_peps_tensors [ti ].T2_bra .shape
216+ diff = jnp .abs (
217+ new_peps_tensors [ti ].T2_bra [
218+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
219+ ]
220+ - old_peps_tensors [ti ].T2_bra [
221+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
222+ ]
223+ )
224+ result += jnp .sum (diff > eps )
225+ measure = measure .at [ti , 7 ].set (
226+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
227+ )
228+ if verbose :
229+ verbose_data .append ((ti , CTM_Enum .T2_bra , jnp .amax (diff )))
230+
231+ old_shape = old_peps_tensors [ti ].T3_ket .shape
232+ new_shape = new_peps_tensors [ti ].T3_ket .shape
233+ diff = jnp .abs (
234+ new_peps_tensors [ti ].T3_ket [
235+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
236+ ]
237+ - old_peps_tensors [ti ].T3_ket [
238+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
239+ ]
240+ )
241+ result += jnp .sum (diff > eps )
242+ measure = measure .at [ti , 8 ].set (
243+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
244+ )
245+ if verbose :
246+ verbose_data .append ((ti , CTM_Enum .T3_ket , jnp .amax (diff )))
247+
248+ old_shape = old_peps_tensors [ti ].T3_bra .shape
249+ new_shape = new_peps_tensors [ti ].T3_bra .shape
250+ diff = jnp .abs (
251+ new_peps_tensors [ti ].T3_bra [
252+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
253+ ]
254+ - old_peps_tensors [ti ].T3_bra [
255+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
256+ ]
257+ )
258+ result += jnp .sum (diff > eps )
259+ measure = measure .at [ti , 9 ].set (
260+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
261+ )
262+ if verbose :
263+ verbose_data .append ((ti , CTM_Enum .T3_bra , jnp .amax (diff )))
264+
265+ old_shape = old_peps_tensors [ti ].T4_ket .shape
266+ new_shape = new_peps_tensors [ti ].T4_ket .shape
267+ diff = jnp .abs (
268+ new_peps_tensors [ti ].T4_ket [
269+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
270+ ]
271+ - old_peps_tensors [ti ].T4_ket [
272+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
273+ ]
274+ )
275+ result += jnp .sum (diff > eps )
276+ measure = measure .at [ti , 10 ].set (
277+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
278+ )
279+ if verbose :
280+ verbose_data .append ((ti , CTM_Enum .T4_ket , jnp .amax (diff )))
281+
282+ old_shape = old_peps_tensors [ti ].T4_bra .shape
283+ new_shape = new_peps_tensors [ti ].T4_bra .shape
284+ diff = jnp .abs (
285+ new_peps_tensors [ti ].T4_bra [
286+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ]
287+ ]
288+ - old_peps_tensors [ti ].T4_bra [
289+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ]
290+ ]
291+ )
292+ result += jnp .sum (diff > eps )
293+ measure = measure .at [ti , 11 ].set (
294+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
295+ )
296+ if verbose :
297+ verbose_data .append ((ti , CTM_Enum .T4_bra , jnp .amax (diff )))
298+ else :
299+ old_shape = old_peps_tensors [ti ].T1 .shape
300+ new_shape = new_peps_tensors [ti ].T1 .shape
301+ diff = jnp .abs (
302+ new_peps_tensors [ti ].T1 [
303+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
304+ ]
305+ - old_peps_tensors [ti ].T1 [
306+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
307+ ]
308+ )
309+ result += jnp .sum (diff > eps )
310+ measure = measure .at [ti , 4 ].set (
311+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
312+ )
313+ if verbose :
314+ verbose_data .append ((ti , CTM_Enum .T1 , jnp .amax (diff )))
315+
316+ old_shape = old_peps_tensors [ti ].T2 .shape
317+ new_shape = new_peps_tensors [ti ].T2 .shape
318+ diff = jnp .abs (
319+ new_peps_tensors [ti ].T2 [
320+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
321+ ]
322+ - old_peps_tensors [ti ].T2 [
323+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
324+ ]
325+ )
326+ result += jnp .sum (diff > eps )
327+ measure = measure .at [ti , 5 ].set (
328+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
329+ )
330+ if verbose :
331+ verbose_data .append ((ti , CTM_Enum .T2 , jnp .amax (diff )))
332+
333+ old_shape = old_peps_tensors [ti ].T3 .shape
334+ new_shape = new_peps_tensors [ti ].T3 .shape
335+ diff = jnp .abs (
336+ new_peps_tensors [ti ].T3 [
337+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
338+ ]
339+ - old_peps_tensors [ti ].T3 [
340+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
341+ ]
342+ )
343+ result += jnp .sum (diff > eps )
344+ measure = measure .at [ti , 6 ].set (
345+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
346+ )
347+ if verbose :
348+ verbose_data .append ((ti , CTM_Enum .T3 , jnp .amax (diff )))
349+
350+ old_shape = old_peps_tensors [ti ].T4 .shape
351+ new_shape = new_peps_tensors [ti ].T4 .shape
352+ diff = jnp .abs (
353+ new_peps_tensors [ti ].T4 [
354+ : old_shape [0 ], : old_shape [1 ], : old_shape [2 ], : old_shape [3 ]
355+ ]
356+ - old_peps_tensors [ti ].T4 [
357+ : new_shape [0 ], : new_shape [1 ], : new_shape [2 ], : new_shape [3 ]
358+ ]
359+ )
360+ result += jnp .sum (diff > eps )
361+ measure = measure .at [ti , 7 ].set (
362+ jnp .linalg .norm (diff ), indices_are_sorted = True , unique_indices = True
363+ )
364+ if verbose :
365+ verbose_data .append ((ti , CTM_Enum .T4 , jnp .amax (diff )))
214366
215367 return result == 0 , jnp .linalg .norm (measure ), verbose_data
216368
@@ -230,16 +382,22 @@ def _ctmrg_body_func(carry):
230382 config ,
231383 ) = carry
232384
233- w_unitcell , norm_smallest_S = do_absorption_step (
234- w_tensors , w_unitcell_last_step , config , state
235- )
385+ if state .ctmrg_split_transfer :
386+ w_unitcell , norm_smallest_S = do_absorption_step_split_transfer (
387+ w_tensors , w_unitcell_last_step , config , state
388+ )
389+ else :
390+ w_unitcell , norm_smallest_S = do_absorption_step (
391+ w_tensors , w_unitcell_last_step , config , state
392+ )
236393
237394 def elementwise_func (old , new , old_corner , conv_eps , config ):
238395 converged , measure , verbose_data = _is_element_wise_converged (
239396 old ,
240397 new ,
241398 conv_eps ,
242399 verbose = config .ctmrg_verbose_output ,
400+ split_transfer = state .ctmrg_split_transfer ,
243401 )
244402 return converged , measure , verbose_data , old_corner
245403
@@ -377,12 +535,22 @@ def calc_ctmrg_env(
377535 norm_smallest_S = jnp .nan
378536 already_tried_chi = {working_unitcell [0 , 0 ][0 ][0 ].chi }
379537
538+ varipeps_global_state .ctmrg_split_transfer = isinstance (
539+ unitcell .get_unique_tensors ()[0 ], PEPS_Tensor_Split_Transfer
540+ )
541+
380542 while True :
381543 tmp_count = 0
382544 corner_singular_vals = None
383545
384546 while any (
385547 i .C1 .shape [0 ] != i .chi for i in working_unitcell .get_unique_tensors ()
548+ ) or (
549+ hasattr (working_unitcell .get_unique_tensors ()[0 ], "T4_ket" )
550+ and any (
551+ i .T4_ket .shape [0 ] != i .interlayer_chi
552+ for i in working_unitcell .get_unique_tensors ()
553+ )
386554 ):
387555 (
388556 _ ,
0 commit comments