Skip to content

Commit dcd9d1b

Browse files
fix: resolve concurrency issues in DefaultCmabService with lock striping
1 parent efbda89 commit dcd9d1b

File tree

2 files changed

+115
-40
lines changed

2 files changed

+115
-40
lines changed

core-api/src/main/java/com/optimizely/ab/cmab/service/DefaultCmabService.java

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.List;
2121
import java.util.Map;
2222
import java.util.TreeMap;
23+
import java.util.concurrent.locks.ReentrantLock;
2324

2425
import org.slf4j.Logger;
2526
import org.slf4j.LoggerFactory;
@@ -37,10 +38,12 @@
3738
public class DefaultCmabService implements CmabService {
3839
public static final int DEFAULT_CMAB_CACHE_SIZE = 10000;
3940
public static final int DEFAULT_CMAB_CACHE_TIMEOUT_SECS = 30*60; // 30 minutes
41+
private static final int NUM_LOCK_STRIPES = 1000;
4042

4143
private final Cache<CmabCacheValue> cmabCache;
4244
private final CmabClient cmabClient;
4345
private final Logger logger;
46+
private final ReentrantLock[] locks;
4447

4548
public DefaultCmabService(CmabClient cmabClient, Cache<CmabCacheValue> cmabCache) {
4649
this(cmabClient, cmabCache, null);
@@ -50,52 +53,64 @@ public DefaultCmabService(CmabClient cmabClient, Cache<CmabCacheValue> cmabCache
5053
this.cmabCache = cmabCache;
5154
this.cmabClient = cmabClient;
5255
this.logger = logger != null ? logger : LoggerFactory.getLogger(DefaultCmabService.class);
56+
this.locks = new ReentrantLock[NUM_LOCK_STRIPES];
57+
for (int i = 0; i < NUM_LOCK_STRIPES; i++) {
58+
this.locks[i] = new ReentrantLock();
59+
}
5360
}
5461

5562
@Override
5663
public CmabDecision getDecision(ProjectConfig projectConfig, OptimizelyUserContext userContext, String ruleId, List<OptimizelyDecideOption> options) {
5764
options = options == null ? Collections.emptyList() : options;
5865
String userId = userContext.getUserId();
59-
Map<String, Object> filteredAttributes = filterAttributes(projectConfig, userContext, ruleId);
6066

61-
if (options.contains(OptimizelyDecideOption.IGNORE_CMAB_CACHE)) {
62-
logger.debug("Ignoring CMAB cache for user '{}' and rule '{}'", userId, ruleId);
63-
return fetchDecision(ruleId, userId, filteredAttributes);
64-
}
67+
int lockIndex = getLockIndex(userId, ruleId);
68+
ReentrantLock lock = locks[lockIndex];
69+
lock.lock();
70+
try {
71+
Map<String, Object> filteredAttributes = filterAttributes(projectConfig, userContext, ruleId);
6572

66-
if (options.contains(OptimizelyDecideOption.RESET_CMAB_CACHE)) {
67-
logger.debug("Resetting CMAB cache for user '{}' and rule '{}'", userId, ruleId);
68-
cmabCache.reset();
69-
}
73+
if (options.contains(OptimizelyDecideOption.IGNORE_CMAB_CACHE)) {
74+
logger.debug("Ignoring CMAB cache for user '{}' and rule '{}'", userId, ruleId);
75+
return fetchDecision(ruleId, userId, filteredAttributes);
76+
}
7077

71-
String cacheKey = getCacheKey(userContext.getUserId(), ruleId);
72-
if (options.contains(OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE)) {
73-
logger.debug("Invalidating CMAB cache for user '{}' and rule '{}'", userId, ruleId);
74-
cmabCache.remove(cacheKey);
75-
}
78+
if (options.contains(OptimizelyDecideOption.RESET_CMAB_CACHE)) {
79+
logger.debug("Resetting CMAB cache for user '{}' and rule '{}'", userId, ruleId);
80+
cmabCache.reset();
81+
}
82+
83+
String cacheKey = getCacheKey(userContext.getUserId(), ruleId);
84+
if (options.contains(OptimizelyDecideOption.INVALIDATE_USER_CMAB_CACHE)) {
85+
logger.debug("Invalidating CMAB cache for user '{}' and rule '{}'", userId, ruleId);
86+
cmabCache.remove(cacheKey);
87+
}
7688

77-
CmabCacheValue cachedValue = cmabCache.lookup(cacheKey);
89+
CmabCacheValue cachedValue = cmabCache.lookup(cacheKey);
7890

79-
String attributesHash = hashAttributes(filteredAttributes);
91+
String attributesHash = hashAttributes(filteredAttributes);
8092

81-
if (cachedValue != null) {
82-
if (cachedValue.getAttributesHash().equals(attributesHash)) {
83-
logger.debug("CMAB cache hit for user '{}' and rule '{}'", userId, ruleId);
84-
return new CmabDecision(cachedValue.getVariationId(), cachedValue.getCmabUuid());
93+
if (cachedValue != null) {
94+
if (cachedValue.getAttributesHash().equals(attributesHash)) {
95+
logger.debug("CMAB cache hit for user '{}' and rule '{}'", userId, ruleId);
96+
return new CmabDecision(cachedValue.getVariationId(), cachedValue.getCmabUuid());
97+
} else {
98+
logger.debug("CMAB cache attributes mismatch for user '{}' and rule '{}', fetching new decision", userId, ruleId);
99+
cmabCache.remove(cacheKey);
100+
}
85101
} else {
86-
logger.debug("CMAB cache attributes mismatch for user '{}' and rule '{}', fetching new decision", userId, ruleId);
87-
cmabCache.remove(cacheKey);
102+
logger.debug("CMAB cache miss for user '{}' and rule '{}'", userId, ruleId);
88103
}
89-
} else {
90-
logger.debug("CMAB cache miss for user '{}' and rule '{}'", userId, ruleId);
91-
}
92104

93-
CmabDecision cmabDecision = fetchDecision(ruleId, userId, filteredAttributes);
94-
logger.debug("CMAB decision is {}", cmabDecision);
95-
96-
cmabCache.save(cacheKey, new CmabCacheValue(attributesHash, cmabDecision.getVariationId(), cmabDecision.getCmabUUID()));
105+
CmabDecision cmabDecision = fetchDecision(ruleId, userId, filteredAttributes);
106+
logger.debug("CMAB decision is {}", cmabDecision);
97107

98-
return cmabDecision;
108+
cmabCache.save(cacheKey, new CmabCacheValue(attributesHash, cmabDecision.getVariationId(), cmabDecision.getCmabUUID()));
109+
110+
return cmabDecision;
111+
} finally {
112+
lock.unlock();
113+
}
99114
}
100115

101116
private CmabDecision fetchDecision(String ruleId, String userId, Map<String, Object> attributes) {
@@ -192,6 +207,13 @@ private String hashAttributes(Map<String, Object> attributes) {
192207
return Integer.toHexString(hash);
193208
}
194209

210+
private int getLockIndex(String userId, String ruleId) {
211+
// Create a hash of userId + ruleId for consistent lock selection
212+
String combined = userId + ruleId;
213+
int hash = MurmurHash3.murmurhash3_x86_32(combined, 0, combined.length(), 0);
214+
return Math.abs(hash) % NUM_LOCK_STRIPES;
215+
}
216+
195217
public static Builder builder() {
196218
return new Builder();
197219
}

core-api/src/test/java/com/optimizely/ab/cmab/DefaultCmabServiceTest.java

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,14 @@
1515
*/
1616
package com.optimizely.ab.cmab;
1717

18-
import java.util.Arrays;
19-
import java.util.Collections;
20-
import java.util.HashMap;
21-
import java.util.LinkedHashMap;
22-
import java.util.List;
23-
import java.util.Map;
24-
25-
import static org.junit.Assert.assertEquals;
26-
import static org.junit.Assert.assertNotNull;
18+
import java.lang.reflect.Method;
19+
import java.util.*;
20+
2721
import org.junit.Before;
2822
import org.junit.Test;
2923
import org.mockito.ArgumentCaptor;
24+
25+
import static org.junit.Assert.*;
3026
import static org.mockito.Matchers.any;
3127
import static org.mockito.Matchers.anyString;
3228
import static org.mockito.Matchers.eq;
@@ -375,4 +371,61 @@ public void testAttributeOrderDoesNotMatterForCaching() {
375371
assertNotNull(decision.getCmabUUID());
376372
verify(mockCmabCache).save(eq(cacheKey), any(CmabCacheValue.class));
377373
}
378-
}
374+
@Test
375+
public void testLockStripingDistribution() {
376+
// Test different combinations to ensure they get different lock indices
377+
String[][] testCases = {
378+
{"user1", "rule1"},
379+
{"user2", "rule1"},
380+
{"user1", "rule2"},
381+
{"user3", "rule3"},
382+
{"user4", "rule4"}
383+
};
384+
385+
Set<Integer> lockIndices = new HashSet<>();
386+
for (String[] testCase : testCases) {
387+
String userId = testCase[0];
388+
String ruleId = testCase[1];
389+
390+
// Use reflection to access the private getLockIndex method
391+
try {
392+
Method getLockIndexMethod = DefaultCmabService.class.getDeclaredMethod("getLockIndex", String.class, String.class);
393+
getLockIndexMethod.setAccessible(true);
394+
395+
int index = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
396+
397+
// Verify index is within expected range
398+
assertTrue("Lock index should be non-negative", index >= 0);
399+
assertTrue("Lock index should be less than NUM_LOCK_STRIPES", index < 1000);
400+
401+
lockIndices.add(index);
402+
} catch (Exception e) {
403+
fail("Failed to invoke getLockIndex method: " + e.getMessage());
404+
}
405+
}
406+
407+
assertTrue("Different user/rule combinations should generally use different locks", lockIndices.size() > 1);
408+
}
409+
410+
@Test
411+
public void testSameUserRuleCombinationUsesConsistentLock() {
412+
String userId = "test_user";
413+
String ruleId = "test_rule";
414+
415+
try {
416+
Method getLockIndexMethod = DefaultCmabService.class.getDeclaredMethod("getLockIndex", String.class, String.class);
417+
getLockIndexMethod.setAccessible(true);
418+
419+
// Get lock index multiple times
420+
int index1 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
421+
int index2 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
422+
int index3 = (Integer) getLockIndexMethod.invoke(cmabService, userId, ruleId);
423+
424+
// All should be the same
425+
assertEquals("Same user/rule should always use same lock", index1, index2);
426+
assertEquals("Same user/rule should always use same lock", index2, index3);
427+
} catch (Exception e) {
428+
fail("Failed to invoke getLockIndex method: " + e.getMessage());
429+
}
430+
}
431+
}

0 commit comments

Comments
 (0)