
Commit 7bdfb83

kossumokaris authored and committed

resolve the image embedding issue in gemma3

1 parent 60a70de commit 7bdfb83

File tree: 2 files changed, +190 −4 lines

  llama_cpp/llama_chat_format.py
  llama_cpp/llava_cpp.py


llama_cpp/llama_chat_format.py

Lines changed: 78 additions & 4 deletions
@@ -3028,6 +3028,26 @@ def __call__(
         )
         return _convert_completion_to_chat(completion_or_chunks, stream=stream)

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        image_bytes = self.load_image(image_url)
+        embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch)
+        if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
+            raise ValueError(
+                f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}"
+            )
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)
+        with suppress_stdout_stderr(disable=self.verbose):
+            self._llava_cpp.llava_eval_image_embed(
+                llama.ctx,
+                embed,
+                llama.n_batch,
+                n_past_p,
+            )
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value
+
     @staticmethod
     def _load_image(image_url: str) -> bytes:
         # TODO: Add Pillow support for other image formats beyond (jpg, png)
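For orientation, the `eval_image` added above is meant to be driven from the handler's prompt loop: text chunks go through the usual tokenize-and-eval path, while image chunks go through `eval_image`, which advances `llama.n_tokens` past the embedded positions and blanks them in `input_ids`. A minimal sketch of such a driver loop (hypothetical, not part of this commit; `handler` and the `(kind, value)` pairs follow the shapes used by `split_text_on_image_urls` below):

    # Hypothetical driver: interleave text and image evaluation the way the
    # chat handlers consume (kind, value) pairs from split_text_on_image_urls.
    def eval_prompt(handler, llama, split_text):
        for kind, value in split_text:
            if kind == "text":
                tokens = llama.tokenize(value.encode("utf8"), add_bos=False, special=True)
                llama.eval(tokens)                # ordinary causal token eval
            else:  # kind == "image_url"
                handler.eval_image(llama, value)  # embeds image, bumps llama.n_tokens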
@@ -3582,10 +3602,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
             if pos != -1:
                 assert len(copied_urls) > 0
                 if pos > 0:
-                    split_text += [("text", remaining[:pos])]
-                split_text += [("text", "\n\n<start_of_image>")]
-                split_text += [("image_url", copied_urls.pop(0))]
-                split_text += [("text", "<end_of_image>\n\n")]
+                    split_text.append(("text", remaining[:pos]))
+                split_text.append(("text", "\n\n<start_of_image>"))
+                split_text.append(("image_url", copied_urls.pop(0)))
+                split_text.append(("text", "<end_of_image>\n\n"))
                 remaining = remaining[pos + len(image_placeholder):]
             else:
                 assert len(copied_urls) == 0
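The four rewritten lines are a like-for-like refactor: `split_text.append(item)` replaces `split_text += [item]`, which builds a throwaway one-element list on every call. The two forms are behaviorally identical:

    a, b = [], []
    a += [("text", "hi")]        # allocates a temporary single-element list
    b.append(("text", "hi"))     # appends in place, no temporary
    assert a == b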
@@ -3608,6 +3628,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
                         image_urls.append(content["url"])
         return image_urls

+    def eval_image(self, llama: llama.Llama, image_url: str):
+        import llama_cpp
+
+        img_bytes = self.load_image(image_url)
+        img_u8_p = self._llava_cpp.clip_image_u8_init()
+        if not self._llava_cpp.clip_image_load_from_bytes(
+            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
+            ctypes.c_size_t(len(img_bytes)),
+            img_u8_p,
+        ):
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to load image.")
+
+        img_f32 = self._llava_cpp.clip_image_f32_batch()
+        img_f32_p = ctypes.byref(img_f32)
+        if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to preprocess image.")
+
+        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
+        n_tokens = 256
+        embed = (ctypes.c_float * (n_tokens * n_embd))()
+        if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
+            self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+            self._llava_cpp.clip_image_u8_free(img_u8_p)
+            raise ValueError("Failed to encode image.")
+
+        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
+        self._llava_cpp.clip_image_u8_free(img_u8_p)
+        llama_cpp.llama_set_causal_attn(llama.ctx, False)
+
+        seq_id_0 = (ctypes.c_int32 * 1)()
+        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
+        for i in range(n_tokens):
+            seq_ids[i] = seq_id_0
+
+        batch = llama_cpp.llama_batch()
+        batch.n_tokens = n_tokens
+        batch.token = None
+        batch.embd = embed
+        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
+        batch.seq_id = seq_ids
+        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
+        batch.logits = (ctypes.c_int8 * n_tokens)()
+
+        if llama_cpp.llama_decode(llama.ctx, batch):
+            raise ValueError("Failed to decode image.")
+
+        llama_cpp.llama_set_causal_attn(llama.ctx, True)
+        # Required to avoid issues with hf tokenizer
+        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
+        llama.n_tokens += n_tokens
+

 def _accumulate_chunks(
     chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
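End to end, the Gemma 3 path above is exercised through the normal chat API: the handler splits the prompt on image placeholders, evaluates each image as 256 non-causal embedding positions, then restores causal attention for text. A hedged usage sketch follows; the model/mmproj file names are placeholders, and the `Gemma3ChatHandler(clip_model_path=...)` constructor is assumed to follow the Llava15ChatHandler convention:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Gemma3ChatHandler

    chat_handler = Gemma3ChatHandler(clip_model_path="mmproj-gemma3.gguf")
    llm = Llama(
        model_path="gemma-3-4b-it.gguf",
        chat_handler=chat_handler,
        n_ctx=4096,  # must leave room for 256 image positions per image
    )
    response = llm.create_chat_completion(
        messages=[{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "Describe this image."},
            ],
        }]
    )
    print(response["choices"][0]["message"]["content"])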

llama_cpp/llava_cpp.py

Lines changed: 112 additions & 0 deletions
@@ -7,6 +7,7 @@
     c_int,
     c_uint8,
     c_float,
+    c_size_t,
     c_void_p,
     POINTER,
     _Pointer,  # type: ignore
@@ -141,6 +142,28 @@ def llava_eval_image_embed(
 ################################################


+# struct clip_image_u8_batch {
+#     struct clip_image_u8 * data;
+#     size_t size;
+# };
+class clip_image_u8_batch(Structure):
+    _fields_ = [
+        ("data", c_void_p),
+        ("size", c_size_t),
+    ]
+
+
+# struct clip_image_f32_batch {
+#     struct clip_image_f32 * data;
+#     size_t size;
+# };
+class clip_image_f32_batch(Structure):
+    _fields_ = [
+        ("data", c_void_p),
+        ("size", c_size_t),
+    ]
+
+
 # /** load mmproj model */
 # CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
 @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
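These two `Structure` subclasses mirror the clip.h structs field for field (an opaque `data` pointer plus a `size_t` count), so ctypes reproduces the C layout. A quick sanity check, assuming a 64-bit build:

    import ctypes
    from llama_cpp.llava_cpp import clip_image_f32_batch

    assert ctypes.sizeof(clip_image_f32_batch) == 16  # 8-byte pointer + 8-byte size_t
    batch = clip_image_f32_batch()
    assert batch.data is None and batch.size == 0     # ctypes zero-initializes fields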
@@ -156,3 +179,92 @@ def clip_model_load(
 def clip_free(ctx: clip_ctx_p, /):
     ...

+
+# CLIP_API struct clip_image_u8 * clip_image_u8_init ();
+@ctypes_function("clip_image_u8_init", [], c_void_p)
+def clip_image_u8_init() -> Optional[c_void_p]:
+    ...
+
+
+# CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
+@ctypes_function("clip_image_u8_free", [c_void_p], None)
+def clip_image_u8_free(img: c_void_p, /):
+    ...
+
+
+# CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
+@ctypes_function("clip_image_f32_free", [c_void_p], None)
+def clip_image_f32_free(img: c_void_p, /):
+    ...
+
+
+# CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
+@ctypes_function("clip_image_u8_batch_free", [POINTER(clip_image_u8_batch)], None)
+def clip_image_u8_batch_free(batch: "_Pointer[clip_image_u8_batch]", /):
+    ...
+
+
+# CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+@ctypes_function("clip_image_f32_batch_free", [POINTER(clip_image_f32_batch)], None)
+def clip_image_f32_batch_free(batch: "_Pointer[clip_image_f32_batch]", /):
+    ...
+
+
+# /** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
+# CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
+@ctypes_function(
+    "clip_image_preprocess",
+    [
+        clip_ctx_p_ctypes,
+        c_void_p,
+        POINTER(clip_image_f32_batch),
+    ],
+    c_bool,
+)
+def clip_image_preprocess(
+    ctx: clip_ctx_p,
+    img: c_void_p,
+    res_imgs: "_Pointer[clip_image_f32_batch]",
+    /,
+) -> bool:
+    ...
+
+
+# CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
+@ctypes_function(
+    "clip_image_batch_encode",
+    [
+        clip_ctx_p_ctypes,
+        c_int,
+        POINTER(clip_image_f32_batch),
+        POINTER(c_float),
+    ],
+    c_bool,
+)
+def clip_image_batch_encode(
+    ctx: clip_ctx_p,
+    n_threads: c_int,
+    imgs: "_Pointer[clip_image_f32_batch]",
+    vec: c_void_p
+) -> bool:
+    ...
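Taken together, the new bindings expose clip.h's load → preprocess → encode → free lifecycle, which is exactly the sequence `eval_image` drives above. A standalone sketch using only the functions declared in this file; the mmproj path, thread count, and the `n_tokens`/`n_embd` sizing are assumptions:

    import ctypes
    from llama_cpp import llava_cpp

    ctx = llava_cpp.clip_model_load(b"mmproj-gemma3.gguf", 0)  # verbosity 0

    data = open("cat.png", "rb").read()
    img_u8 = llava_cpp.clip_image_u8_init()
    if not llava_cpp.clip_image_load_from_bytes(
        ctypes.create_string_buffer(data, len(data)), ctypes.c_size_t(len(data)), img_u8
    ):
        raise RuntimeError("could not parse image bytes")

    batch = llava_cpp.clip_image_f32_batch()
    if not llava_cpp.clip_image_preprocess(ctx, img_u8, ctypes.byref(batch)):
        raise RuntimeError("preprocess failed")

    n_embd, n_tokens = 2560, 256  # model-dependent; assumed here
    vec = (ctypes.c_float * (n_tokens * n_embd))()
    if not llava_cpp.clip_image_batch_encode(ctx, 4, ctypes.byref(batch), vec):
        raise RuntimeError("encode failed")

    # Free in reverse order of acquisition, mirroring eval_image above.
    llava_cpp.clip_image_f32_batch_free(ctypes.byref(batch))
    llava_cpp.clip_image_u8_free(img_u8)
    llava_cpp.clip_free(ctx)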
