From cda561950a0cc1430fd3fc617a16713216e86033 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Fri, 23 May 2025 11:28:53 -0700 Subject: [PATCH 1/2] MockMemcacheClient: implement gets, gets_many, cas --- pymemcache/test/test_client.py | 73 +++++++++++++++++++++++++--------- pymemcache/test/utils.py | 43 ++++++++++++++++++-- 2 files changed, 94 insertions(+), 22 deletions(-) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 9b38394d..9e6a47f6 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1511,18 +1511,11 @@ def increment(self, obj): class TestMockClient(ClientTestMixin, unittest.TestCase): - def make_client(self, mock_socket_values, **kwargs): - client = MockMemcacheClient("localhost", **kwargs) - client.sock = MockSocket(list(mock_socket_values)) - return client + def make_client(self, mock_socket_values=None, **kwargs): + return MockMemcacheClient("localhost", **kwargs) def test_get_found(self): - client = self.make_client( - [ - b"STORED\r\n", - b"VALUE key 0 5\r\nvalue\r\nEND\r\n", - ] - ) + client = self.make_client() result = client.set(b"key", b"value", noreply=False) result = client.get(b"key") assert result == b"value" @@ -1539,15 +1532,7 @@ def deserialize(self, key, value, flags): return json.loads(value.decode("UTF-8")) return value - client = self.make_client( - [ - b"STORED\r\n", - b"VALUE key1 0 5\r\nhello\r\nEND\r\n", - b"STORED\r\n", - b'VALUE key2 0 18\r\n{"hello": "world"}\r\nEND\r\n', - ], - serde=JsonSerde(), - ) + client = self.make_client(serde=JsonSerde()) result = client.set(b"key1", b"hello", noreply=False) result = client.get(b"key1") @@ -1557,6 +1542,56 @@ def deserialize(self, key, value, flags): result = client.get(b"key2") assert result == dict(hello="world") + def test_gets_not_found(self): + client = self.make_client() + result = client.gets(b"key") + assert result == (None, None) + + def test_gets_not_found_defaults(self): + client = self.make_client() + result = client.gets(b"key", default="foo", cas_default="bar") + assert result == ("foo", "bar") + + @mock.patch('time.time_ns', return_value=10) + def test_gets_found(self, _): + client = self.make_client() + result = client.set(b"key", b"value", noreply=False) + result = client.gets(b"key") + assert result == (b"value", b"10") + + def test_gets_many_none_found(self): + client = self.make_client([b"END\r\n"]) + result = client.gets_many([b"key1", b"key2"]) + assert result == {} + + @mock.patch('time.time_ns', return_value=11) + def test_gets_many_some_found(self, _): + client = self.make_client() + result = client.set(b"key1", b"value", noreply=False) + result = client.gets_many([b"key1", b"key2"]) + assert result == {b"key1": (b"value", b"11")} + + @mock.patch('time.time_ns', return_value=123) + def test_cas_stored(self, _): + client = self.make_client() + client.set(b"key", b"existing") + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is True + + result = client.get(b"key") + assert result == b"value" + + def test_cas_exists(self): + client = self.make_client() + client.set(b"key", b"existing") + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is False + + def test_cas_not_found(self): + client = self.make_client() + result = client.cas(b"key", b"value", b"123", noreply=False) + assert result is None + class TestPrefixedClient(ClientTestMixin, unittest.TestCase): def make_client(self, mock_socket_values, **kwargs): diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index 52b17321..e5ccdf9f 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -38,6 +38,7 @@ def __init__( **kwargs, ): self._contents = {} + self._cas_ids = {} # maps keys to bytes CAS tokens def _serializer(key, value): if isinstance(value, str): @@ -68,6 +69,7 @@ def check_key(self, key): def clear(self): """Method used to clear/reset mock cache""" self._contents.clear() + self._cas_ids.clear() def get(self, key, default=None): key = self.check_key(key) @@ -92,6 +94,28 @@ def get_many(self, keys): get_multi = get_many + def gets(self, key, default=None, cas_default=None): + not_found = [] + + value = self.get(key, default=not_found) + if value is not_found: + return default, cas_default + + cas_token = self._cas_ids.setdefault(key, str(time.time_ns()).encode()) + return value, cas_token + + def gets_many(self, keys): + not_found = [] + + out = {} + for key in keys: + value, cas = self.gets(key, default=not_found) + if value is not not_found: + out[key] = (value, cas) + return out + + get_multi = get_many + def set(self, key, value, expire=0, noreply=True, flags=None): key = self.check_key(key) if isinstance(value, str) and not isinstance(value, bytes): @@ -106,6 +130,7 @@ def set(self, key, value, expire=0, noreply=True, flags=None): expire += time.time() self._contents[key] = expire, value, flags + self._cas_ids[key] = str(time.time_ns()).encode() return True def set_many(self, values, expire=0, noreply=True, flags=None): @@ -189,7 +214,7 @@ def stats(self, *_args): "stat_key_prefix": "", "umask": 0o644, "detail_enabled": False, - "cas_enabled": False, + "cas_enabled": True, "auth_enabled_sasl": False, "maxconns_fast": False, "slab_reassign": False, @@ -203,8 +228,20 @@ def replace(self, key, value, expire=0, noreply=True, flags=None): self.set(key, value, expire, noreply, flags=flags) return noreply or present - def cas(self, key, value, cas, expire=0, noreply=False, flags=None): - raise MemcacheClientError("CAS is not enabled for this instance") + def cas(self, key, value, cas_token, expire=0, noreply=False, **kwargs): + if not isinstance(cas_token, (int, str, bytes)): + raise MemcacheIllegalInputError(f'cas must be integer, string, or bytes, got bad value: {cas_token}') + + key = self.check_key(key) + + if key not in self._contents: + self.set(key, value, noreply=noreply, **kwargs) + return True if noreply else None + + elif self._cas_ids.get(key) != cas_token: + return True if noreply else False + + return self.set(key, value, noreply=noreply, **kwargs) def touch(self, key, expire=0, noreply=True): current = self.get(key) From b42771638f19c86aa7f220fc6e1798e8ac682bd7 Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Fri, 23 May 2025 14:18:16 -0700 Subject: [PATCH 2/2] MockMemcacheClient.cas: don't set the value if the key doesn't exist --- pymemcache/test/test_client.py | 3 +++ pymemcache/test/utils.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index 9e6a47f6..f8e28197 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1592,6 +1592,9 @@ def test_cas_not_found(self): result = client.cas(b"key", b"value", b"123", noreply=False) assert result is None + result = client.get(b"key") + assert result is None + class TestPrefixedClient(ClientTestMixin, unittest.TestCase): def make_client(self, mock_socket_values, **kwargs): diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index e5ccdf9f..4c750d76 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -235,7 +235,6 @@ def cas(self, key, value, cas_token, expire=0, noreply=False, **kwargs): key = self.check_key(key) if key not in self._contents: - self.set(key, value, noreply=noreply, **kwargs) return True if noreply else None elif self._cas_ids.get(key) != cas_token: