2222
2323import asyncio
2424from concurrent .futures import ThreadPoolExecutor
25- from typing import Awaitable , Callable , NamedTuple , Set , Union , List
25+ from collections import defaultdict
26+ from typing import (
27+ Optional , Union ,
28+ Awaitable , Callable ,
29+ NamedTuple , List , Dict , Set ,
30+ )
2631
2732from .encoding import to_bytes
2833
3641
3742class Pass (NamedTuple ):
3843 block_index : int
39- index : int
40- byte : int
44+ solved : List [int ]
4145
4246
4347class Fail (NamedTuple ):
@@ -46,13 +50,7 @@ class Fail(NamedTuple):
4650 is_critical : bool = False
4751
4852
49- class Done (NamedTuple ):
50- block_index : int
51- C0 : List [int ]
52- X1 : List [int ]
53-
54-
55- ResultType = Union [Pass , Fail , Done ]
53+ ResultType = Union [Pass , Fail ]
5654
5755OracleFunc = Callable [[bytes ], bool ]
5856ResultCallback = Callable [[ResultType ], bool ]
@@ -68,7 +66,7 @@ class Context(NamedTuple):
6866
6967 tasks : Set [Awaitable [ResultType ]]
7068
71- latest_plaintext : List [ int ]
69+ solved_counts : Dict [ int , int ]
7270 plaintext : List [int ]
7371
7472 result_callback : ResultCallback
@@ -122,10 +120,10 @@ async def solve_async(ciphertext: bytes,
122120 ctx .tasks .remove (task )
123121
124122 if isinstance (result , Pass ):
125- update_latest_plaintext (
126- ctx , result .block_index , result .index , result . byte )
127- if isinstance ( result , Done ):
128- update_plaintext ( ctx , result . block_index , result . C0 , result . X1 )
123+ if len ( result . solved ) >= ctx . solved_counts [ result . block_index ]:
124+ update_plaintext ( ctx , result .block_index , result .solved )
125+ ctx . solved_counts [ result . block_index ] = len ( result . solved )
126+ ctx . plaintext_callback ( ctx . plaintext )
129127
130128 if len (ctx .tasks ) == 0 :
131129 break
@@ -151,59 +149,71 @@ def create_solve_context(ciphertext, block_size, oracle, parallel,
151149 for i in range (0 , len (ciphertext ), block_size ):
152150 cipher_blocks .append (ciphertext [i :i + block_size ])
153151
152+ solved_counts = defaultdict (lambda : 0 )
153+
154154 plaintext = [None ] * (len (cipher_blocks ) - 1 ) * block_size
155- latest_plaintext = plaintext .copy ()
156155
157156 executor = ThreadPoolExecutor (parallel )
158157 loop = asyncio .get_event_loop ()
159158 ctx = Context (block_size , oracle , executor , loop , tasks ,
160- latest_plaintext , plaintext ,
159+ solved_counts , plaintext ,
161160 result_callback , plaintext_callback )
162161
163162 for i in range (1 , len (cipher_blocks )):
164- run_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
163+ add_solve_block_task (ctx , i , cipher_blocks [i - 1 ], cipher_blocks [i ], [])
165164
166165 return ctx
167166
168167
169- def run_block_task (ctx : Context , block_index , C0 , C1 , X1 ):
170- future = solve_block (ctx , block_index , C0 , C1 , X1 )
168+ def add_solve_block_task (ctx : Context , block_index : int , C0 : List [int ],
169+ C1 : List [int ], X1_suffix : List [int ]):
170+ future = solve_block (ctx , block_index , C0 , C1 , X1_suffix )
171171 task = ctx .loop .create_task (future )
172172 ctx .tasks .add (task )
173173
174174
175175async def solve_block (ctx : Context , block_index : int , C0 : List [int ],
176- C1 : List [int ], X1 : List [int ] = []) -> ResultType :
176+ C1 : List [int ], X1_suffix : List [int ] = []) -> ResultType :
177+
178+ assert len (C0 ) == ctx .block_size
179+ assert len (C1 ) == ctx .block_size
180+ assert len (X1_suffix ) in range (ctx .block_size + 1 )
181+
177182 # X1 = decrypt(C1)
178183 # P1 = xor(C0, X1)
184+ C0_suffix = C0 [len (C0 )- len (X1_suffix ):]
185+ P1_suffix = [c ^ x for c , x in zip (C0_suffix , X1_suffix )]
179186
180- if len (X1 ) == ctx .block_size :
181- return Done (block_index , C0 , X1 )
187+ if len (P1_suffix ) < ctx .block_size :
188+ result = await exploit_oracle (ctx , block_index , C0 , C1 , X1_suffix )
189+ if isinstance (result , Fail ):
190+ return result
182191
183- assert len (C0 ) == ctx .block_size
184- assert len (C1 ) == ctx .block_size
185- assert len (X1 ) in range (ctx .block_size )
192+ return Pass (block_index , P1_suffix )
186193
187- index = ctx .block_size - len (X1 ) - 1
188- padding = len (X1 ) + 1
194+
195+ async def exploit_oracle (ctx : Context , block_index : int ,
196+ C0 : List [int ], C1 : List [int ],
197+ X1_suffix : List [int ]) -> Optional [Fail ]:
198+ index = ctx .block_size - len (X1_suffix ) - 1
199+ padding = len (X1_suffix ) + 1
189200
190201 C0_test = C0 .copy ()
191- for i in range (len (X1 )):
192- C0_test [- i - 1 ] = X1 [- i - 1 ] ^ padding
202+ for i in range (len (X1_suffix )):
203+ C0_test [- i - 1 ] = X1_suffix [- i - 1 ] ^ padding
193204 hits = list (await get_oracle_hits (ctx , C0_test , C1 , index ))
194205
195- invalid = len (X1 ) == 0 and len (hits ) not in (1 , 2 )
196- invalid |= len (X1 ) > 0 and len (hits ) != 1
206+ # Check if the number of hits is invalid
207+ invalid = len (X1_suffix ) == 0 and len (hits ) not in (1 , 2 )
208+ invalid |= len (X1_suffix ) > 0 and len (hits ) != 1
197209 if invalid :
198- message = 'unexpected number of hits: block={} index={} n={}' \
199- . format ( block_index , index , len ( hits ))
210+ message = f'invalid number of hits: { len ( hits ) } '
211+ message = f' { message } (block: { block_index } , byte: { index } )'
200212 return Fail (block_index , message )
201213
202214 for byte in hits :
203- X1_test = [byte ^ padding , * X1 ]
204- run_block_task (ctx , block_index , C0 , C1 , X1_test )
205-
206- return Pass (block_index , index , byte ^ padding ^ C0 [index ])
215+ X1_test = [byte ^ padding , * X1_suffix ]
216+ add_solve_block_task (ctx , block_index , C0 , C1 , X1_test )
207217
208218
209219async def get_oracle_hits (ctx : Context , C0 : List [int ], C1 : List [int ],
@@ -228,28 +238,10 @@ async def get_oracle_hits(ctx: Context, C0: List[int], C1: List[int],
228238 return hits
229239
230240
231- def update_latest_plaintext (ctx : Context , block_index : int , index : int ,
232- byte : int ):
233-
234- i = (block_index - 1 ) * ctx .block_size + index
235- ctx .latest_plaintext [i ] = byte
236- ctx .plaintext_callback (ctx .latest_plaintext )
237-
238-
239- def update_plaintext (ctx : Context , block_index : int , C0 : List [int ],
240- X1 : List [int ]):
241-
242- assert len (C0 ) == len (X1 ) == ctx .block_size
243- block = compute_plaintext (C0 , X1 )
244-
245- i = (block_index - 1 ) * ctx .block_size
246- ctx .latest_plaintext [i :i + ctx .block_size ] = block
247- ctx .plaintext [i :i + ctx .block_size ] = block
248- ctx .plaintext_callback (ctx .plaintext )
249-
250-
251- def compute_plaintext (C0 : List [int ], X1 : List [int ]):
252- return [c ^ x for c , x in zip (C0 , X1 )]
241+ def update_plaintext (ctx : Context , block_index : int , solved_suffix : List [int ]):
242+ j = block_index * ctx .block_size
243+ i = j - len (solved_suffix )
244+ ctx .plaintext [i :j ] = solved_suffix
253245
254246
255247def convert_to_bytes (byte_list : List [int ], replacement = b' ' ):
0 commit comments