@@ -3028,6 +3028,26 @@ def __call__(
30283028 )
30293029 return _convert_completion_to_chat (completion_or_chunks , stream = stream )
30303030
def eval_image(self, llama: llama.Llama, image_url: str):
    """Embed an image with llava and advance the llama context past it.

    The image is fetched via ``self.load_image``, embedded through the
    llava projector, and evaluated into ``llama``'s KV cache.

    Raises:
        ValueError: if the image positions would overflow the context window.
    """
    raw_bytes = self.load_image(image_url)
    img_embed = self._embed_image_bytes(raw_bytes, llama.context_params.n_threads_batch)
    needed = llama.n_tokens + img_embed.contents.n_image_pos
    if needed > llama.n_ctx():
        raise ValueError(
            f"Prompt exceeds n_ctx: {llama.n_tokens + img_embed.contents.n_image_pos} > {llama.n_ctx()}"
        )
    position = ctypes.c_int(llama.n_tokens)
    with suppress_stdout_stderr(disable=self.verbose):
        self._llava_cpp.llava_eval_image_embed(
            llama.ctx,
            img_embed,
            llama.n_batch,
            ctypes.pointer(position),
        )
    # Fill the consumed span with -1 sentinels; required to avoid issues
    # with the hf tokenizer seeing stale token ids.
    llama.input_ids[llama.n_tokens : position.value] = -1
    llama.n_tokens = position.value
30313051 @staticmethod
30323052 def _load_image (image_url : str ) -> bytes :
30333053 # TODO: Add Pillow support for other image formats beyond (jpg, png)
@@ -3582,10 +3602,10 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
35823602 if pos != - 1 :
35833603 assert len (copied_urls ) > 0
35843604 if pos > 0 :
3585- split_text += [( "text" , remaining [:pos ])]
3586- split_text += [( "text" , "\n \n <start_of_image>" )]
3587- split_text += [( "image_url" , copied_urls .pop (0 ))]
3588- split_text += [( "text" , "<end_of_image>\n \n " )]
3605+ split_text . append (( "text" , remaining [:pos ]))
3606+ split_text . append (( "text" , "\n \n <start_of_image>" ))
3607+ split_text . append (( "image_url" , copied_urls .pop (0 )))
3608+ split_text . append (( "text" , "<end_of_image>\n \n " ))
35893609 remaining = remaining [pos + len (image_placeholder ):]
35903610 else :
35913611 assert len (copied_urls ) == 0
@@ -3608,6 +3628,60 @@ def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
36083628 image_urls .append (content ["url" ])
36093629 return image_urls
36103630
def eval_image(self, llama: llama.Llama, image_url: str):
    """Encode an image with the CLIP tower and decode its embeddings into the context.

    The image is loaded, preprocessed, and batch-encoded into ``n_tokens``
    embedding vectors (fixed at 256 positions here — presumably the vision
    projector's output length; TODO confirm against the mmproj config),
    then fed to ``llama_decode`` with causal attention temporarily disabled
    so the image positions can attend to each other.

    Parameters:
        llama: the llama.Llama instance whose context receives the image.
        image_url: URL or data URI resolved by ``self.load_image``.

    Raises:
        ValueError: if the image cannot be loaded, preprocessed, encoded,
            or decoded.
    """
    import llama_cpp

    img_bytes = self.load_image(image_url)
    img_u8_p = self._llava_cpp.clip_image_u8_init()
    if not self._llava_cpp.clip_image_load_from_bytes(
        ctypes.create_string_buffer(img_bytes, len(img_bytes)),
        ctypes.c_size_t(len(img_bytes)),
        img_u8_p,
    ):
        self._llava_cpp.clip_image_u8_free(img_u8_p)
        raise ValueError("Failed to load image.")

    img_f32 = self._llava_cpp.clip_image_f32_batch()
    img_f32_p = ctypes.byref(img_f32)
    if not self._llava_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
        self._llava_cpp.clip_image_u8_free(img_u8_p)
        raise ValueError("Failed to preprocess image.")

    n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
    n_tokens = 256  # embedding positions produced per image by this projector
    embed = (ctypes.c_float * (n_tokens * n_embd))()
    if not self._llava_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
        self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
        self._llava_cpp.clip_image_u8_free(img_u8_p)
        raise ValueError("Failed to encode image.")

    # Native image buffers are no longer needed once the embeddings exist.
    self._llava_cpp.clip_image_f32_batch_free(img_f32_p)
    self._llava_cpp.clip_image_u8_free(img_u8_p)

    # The image positions must attend bidirectionally, so turn causal
    # masking off for the decode and restore it afterwards.
    llama_cpp.llama_set_causal_attn(llama.ctx, False)
    try:
        # All positions share sequence 0; one extra (NULL) seq_id slot is
        # allocated, matching the original n_tokens + 1 sizing.
        seq_id_0 = (ctypes.c_int32 * 1)()
        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
        for i in range(n_tokens):
            seq_ids[i] = seq_id_0

        batch = llama_cpp.llama_batch()
        batch.n_tokens = n_tokens
        batch.token = None  # embeddings are supplied instead of token ids
        batch.embd = embed
        batch.pos = (ctypes.c_int32 * n_tokens)(
            *[i + llama.n_tokens for i in range(n_tokens)]
        )
        batch.seq_id = seq_ids
        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
        batch.logits = (ctypes.c_int8 * n_tokens)()  # no logits requested

        if llama_cpp.llama_decode(llama.ctx, batch):
            raise ValueError("Failed to decode image.")
    finally:
        # BUG FIX: previously a decode failure raised before causal attention
        # was re-enabled, leaving the shared context in a non-causal state.
        # Restore it unconditionally.
        llama_cpp.llama_set_causal_attn(llama.ctx, True)

    # Required to avoid issues with hf tokenizer
    llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
    llama.n_tokens += n_tokens
36113685
36123686def _accumulate_chunks (
36133687 chunks_iterator : Iterator [llama_types .CreateCompletionStreamResponse ],
0 commit comments