Skip to content

Commit 525255a

Browse files
committed
port kv_cache to new memory
1 parent 1d23ae0 commit 525255a

File tree

7 files changed

+72
-335
lines changed

7 files changed

+72
-335
lines changed

examples/notebooks/Batching.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@
230230
"outputs": [],
231231
"source": [
232232
"for i in range(n_parallel):\n",
233-
" llama_cpp.llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens)"
233+
" llama_cpp.llama_kv_self_seq_cp(ctx, 0, i, 0, batch.n_tokens)"
234234
]
235235
},
236236
{

llama_cpp/_ctypes_extensions.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import ctypes
66
import functools
77
import pathlib
8+
import logging
9+
import traceback
810

911
from typing import (
1012
Any,
@@ -18,6 +20,9 @@
1820
)
1921
from typing_extensions import TypeAlias
2022

23+
# Configure logging
24+
logging.basicConfig(level=logging.INFO)
25+
logger = logging.getLogger("llama_cpp.binding")
2126

2227
# Load the library
2328
def load_shared_library(lib_base_name: str, base_path: pathlib.Path):
@@ -110,11 +115,21 @@ def ctypes_function(
110115
):
111116
def decorator(f: F) -> F:
112117
if enabled:
118+
print(f"Setting up binding for C function: {name}") # Print when binding is created
113119
func = getattr(lib, name)
114120
func.argtypes = argtypes
115121
func.restype = restype
116-
functools.wraps(f)(func)
117-
return func
122+
123+
@functools.wraps(f)
124+
def wrapper(*args, **kwargs):
125+
print(f">>> Calling {name} with args: {args}") # Print right before C call
126+
sys.stdout.flush() # Force flush to ensure we see the output
127+
result = func(*args, **kwargs)
128+
print(f"<<< {name} returned successfully") # Print after successful return
129+
sys.stdout.flush()
130+
return result
131+
132+
return wrapper
118133
else:
119134
return f
120135

llama_cpp/_internals.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -289,20 +289,20 @@ def n_ctx(self) -> int:
289289
def pooling_type(self) -> int:
290290
return llama_cpp.llama_pooling_type(self.ctx)
291291

292-
def kv_cache_clear(self):
293-
llama_cpp.llama_kv_cache_clear(self.ctx)
292+
def kv_self_clear(self):
293+
llama_cpp.llama_kv_self_clear(self.ctx)
294294

295-
def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
296-
llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
295+
def kv_self_seq_rm(self, seq_id: int, p0: int, p1: int):
296+
llama_cpp.llama_kv_self_seq_rm(self.ctx, seq_id, p0, p1)
297297

298-
def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
299-
llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
298+
def kv_self_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
299+
llama_cpp.llama_kv_self_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
300300

301-
def kv_cache_seq_keep(self, seq_id: int):
302-
llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
301+
def kv_self_seq_keep(self, seq_id: int):
302+
llama_cpp.llama_kv_self_seq_keep(self.ctx, seq_id)
303303

304-
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
305-
llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
304+
def kv_self_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
305+
llama_cpp.llama_kv_self_seq_add(self.ctx, seq_id, p0, p1, shift)
306306

307307
def get_state_size(self) -> int:
308308
return llama_cpp.llama_get_state_size(self.ctx)

llama_cpp/llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def eval(self, tokens: Sequence[int]):
581581
Args:
582582
tokens: The list of tokens to evaluate.
583583
"""
584-
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
584+
self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1)
585585
for i in range(0, len(tokens), self.n_batch):
586586
batch = tokens[i : min(len(tokens), i + self.n_batch)]
587587
n_past = self.n_tokens
@@ -889,7 +889,7 @@ def generate(
889889

890890
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
891891
self.n_tokens = sample_idx
892-
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
892+
self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1)
893893
break
894894

895895
if self.draft_model is not None:
@@ -985,7 +985,7 @@ def embed(
985985
data: Union[List[List[float]], List[List[List[float]]]] = []
986986

987987
def decode_batch(seq_sizes: List[int]):
988-
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
988+
llama_cpp.llama_kv_self_clear(self._ctx.ctx)
989989
self._ctx.decode(self._batch)
990990
self._batch.reset()
991991

@@ -1056,7 +1056,7 @@ def decode_batch(seq_sizes: List[int]):
10561056

10571057
output = data[0] if isinstance(input, str) else data
10581058

1059-
llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
1059+
llama_cpp.llama_kv_self_clear(self._ctx.ctx)
10601060
self.reset()
10611061

10621062
if return_count:

llama_cpp/llama_chat_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2847,7 +2847,7 @@ def __call__(
28472847

28482848
# Evaluate prompt
28492849
llama.reset()
2850-
llama._ctx.kv_cache_clear()
2850+
llama._ctx.kv_self_clear()
28512851
for type_, value in split_text:
28522852
if type_ == "text":
28532853
tokens = llama.tokenize(

0 commit comments

Comments (0)