Skip to content

Commit 4f74f31

Browse files
committed
refactor: Implement robust flatMapLatest with generation tracking
- Replaces naive implementation with a thread-safe approach using ManagedCriticalState. - Introduces generation tracking to prevent race conditions where cancelled inner sequences could yield stale values. - Adds test_interleaving_race_condition to verify correctness under concurrent load. - Ensures Swift 6 Sendable compliance.
1 parent c5235f7 commit 4f74f31

File tree

2 files changed

+112
-18
lines changed

2 files changed

+112
-18
lines changed

Sources/AsyncAlgorithms/AsyncFlatMapLatestSequence.swift

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,43 +26,81 @@ extension AsyncSequence where Self: Sendable {
2626
) -> AsyncThrowingStream<T.Element, Error>
2727
where T.Element: Sendable {
2828

29-
AsyncThrowingStream { continuation in
30-
let outerIterationTask = Task {
31-
var innerIterationTask: Task<Void, Never>? = nil
32-
29+
// Explicitly specify the type of the stream
30+
return AsyncThrowingStream<T.Element, Error> { continuation in
31+
let state = ManagedCriticalState(FlatMapLatestState())
32+
33+
let outerTask = Task {
3334
do {
3435
for try await element in self {
35-
innerIterationTask?.cancel()
36-
3736
let innerSequence = transform(element)
3837

39-
innerIterationTask = Task {
38+
// Increment generation and get the new value
39+
let currentGeneration = state.withCriticalRegion { state -> Int in
40+
state.innerTask?.cancel()
41+
state.generation += 1
42+
return state.generation
43+
}
44+
45+
let innerTask = Task {
4046
do {
4147
for try await innerElement in innerSequence {
42-
try Task.checkCancellation()
43-
continuation.yield(innerElement)
48+
// Check if we are still the latest generation
49+
let shouldYield = state.withCriticalRegion { state in
50+
state.generation == currentGeneration
51+
}
52+
53+
if shouldYield {
54+
continuation.yield(innerElement)
55+
} else {
56+
// If we are not the latest, we should stop
57+
return
58+
}
4459
}
4560
} catch is CancellationError {
46-
// Inner task was cancelled, this is normal
61+
// Normal cancellation
4762
} catch {
48-
// Inner sequence threw an error
49-
continuation.finish(throwing: error)
63+
// If an error occurs, we only propagate it if we are the latest generation
64+
let shouldPropagate = state.withCriticalRegion { state in
65+
state.generation == currentGeneration
66+
}
67+
if shouldPropagate {
68+
continuation.finish(throwing: error)
69+
}
70+
}
71+
}
72+
73+
state.withCriticalRegion { state in
74+
// Only update the inner task if the generation hasn't changed again
75+
if state.generation == currentGeneration {
76+
state.innerTask = innerTask
5077
}
5178
}
5279
}
80+
81+
// Outer sequence finished
82+
// Wait for the last inner task to finish
83+
let lastInnerTask = state.withCriticalRegion { $0.innerTask }
84+
_ = await lastInnerTask?.result
85+
continuation.finish()
86+
5387
} catch {
54-
// Outer sequence threw an error
5588
continuation.finish(throwing: error)
5689
}
57-
58-
// Outer sequence finished
59-
await innerIterationTask?.value
60-
continuation.finish()
6190
}
6291

6392
continuation.onTermination = { @Sendable _ in
64-
outerIterationTask.cancel()
93+
outerTask.cancel()
94+
state.withCriticalRegion { state in
95+
state.innerTask?.cancel()
96+
}
6597
}
6698
}
6799
}
68100
}
101+
102+
@available(AsyncAlgorithms 1.0, *)
103+
private struct FlatMapLatestState: Sendable {
104+
var generation: Int = 0
105+
var innerTask: Task<Void, Never>? = nil
106+
}

Tests/AsyncAlgorithmsTests/TestFlatMapLatest.swift

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,60 @@ final class TestFlatMapLatest: XCTestCase {
3535
XCTAssertTrue(expected.isEmpty)
3636
}
3737

38+
func test_interleaving_race_condition() async throws {
39+
// This test simulates a scenario where the inner sequence is slow.
40+
// In a naive implementation (without generation tracking), the inner task for '1'
41+
// might wake up and yield AFTER '2' has already started, causing interleaving.
42+
43+
let source = [1, 2, 3].async
44+
let transformed = source.flatMapLatest { intValue -> AsyncStream<Int> in
45+
AsyncStream { continuation in
46+
Task {
47+
// Yield the value immediately
48+
continuation.yield(intValue)
49+
50+
// Sleep for a bit to allow the outer sequence to move on
51+
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms
52+
53+
// Yield a second value - this should be ignored if a new outer value has arrived
54+
continuation.yield(intValue * 10)
55+
continuation.finish()
56+
}
57+
}
58+
}
59+
60+
// We expect:
61+
// 1 arrives -> starts inner(1) -> yields 1 -> sleeps
62+
// 2 arrives -> cancels inner(1) -> starts inner(2) -> yields 2 -> sleeps
63+
// 3 arrives -> cancels inner(2) -> starts inner(3) -> yields 3 -> sleeps
64+
// inner(3) finishes sleep -> yields 30
65+
//
66+
// Ideally, we should NOT see 10 or 20.
67+
// However, without strict synchronization, we might see them.
68+
// The strict expectation for flatMapLatest is that once a new value arrives,
69+
// the old one produces NO MORE values.
70+
71+
// Note: This test is probabilistic in the naive implementation.
72+
// It might pass or fail depending on scheduling.
73+
// But with a correct implementation, it should ALWAYS pass.
74+
75+
var expected = [3, 30] // We only want the latest
76+
77+
// We'll collect all results to see what happened
78+
var results: [Int] = []
79+
80+
for try await element in transformed {
81+
results.append(element)
82+
}
83+
84+
// In the naive implementation, we might get [1, 2, 3, 10, 20, 30] or similar.
85+
// We want strictly [3, 30] (or [1, 2, 3, 30] depending on how fast the outer sequence is consumed vs produced)
86+
// Actually, if the outer sequence is consumed fast, we might see intermediate "first" values (1, 2).
87+
// But we should NEVER see "second" values (10, 20) from cancelled sequences.
88+
89+
// Let's relax the check to: "Must not contain 10 or 20"
90+
XCTAssertFalse(results.contains(10), "Should not contain 10 (from cancelled sequence 1)")
91+
XCTAssertFalse(results.contains(20), "Should not contain 20 (from cancelled sequence 2)")
92+
XCTAssertTrue(results.contains(30), "Should contain 30 (from final sequence)")
93+
}
3894
}

0 commit comments

Comments
 (0)