diff --git a/lagent/llms/lmdeploy_wrapper.py b/lagent/llms/lmdeploy_wrapper.py
index b5813e7..0ed64bc 100644
--- a/lagent/llms/lmdeploy_wrapper.py
+++ b/lagent/llms/lmdeploy_wrapper.py
@@ -248,7 +248,7 @@ def __init__(self,
                 '`do_sample` parameter is not supported by lmdeploy until '
                 f'v0.6.0, but currently using lmdeloy {self.str_version}')
         super().__init__(path=path, **kwargs)
-        backend_config = copy.deepcopy(pipeline_cfg)
+        backend_config = dict(copy.deepcopy(pipeline_cfg))
         backend_config.update(tp=tp)
         backend_config = {
             k: v
diff --git a/lagent/llms/openai.py b/lagent/llms/openai.py
index 7418a65..d1a00de 100644
--- a/lagent/llms/openai.py
+++ b/lagent/llms/openai.py
@@ -358,7 +358,7 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = gen_params.pop('max_new_tokens')
         if max_tokens <= 0:
             return '', ''
@@ -374,23 +374,19 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
         if 'repetition_penalty' in gen_params:
             gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty')
 
-        # Model-specific processing
         data = {}
-        if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
+        if any(x in model_type.lower() for x in ['o1', 'o3', 'o4']):
+            data = {'model': model_type, 'messages': messages, 'n': 1}
+        else:
             if 'top_k' in gen_params:
                 warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
                 gen_params.pop('top_k')
             gen_params.pop('skip_special_tokens', None)
             gen_params.pop('session_id', None)
+            data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
             if json_mode:
                 data['response_format'] = {'type': 'json_object'}
-        elif model_type.lower().startswith('internlm'):
-            data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
-            if json_mode:
-                data['response_format'] = {'type': 'json_object'}
-        else:
-            raise NotImplementedError(f'Model type {model_type} is not supported')
 
         return header, data
@@ -756,7 +752,7 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
         gen_params = gen_params.copy()
 
         # Hold out 100 tokens due to potential errors in token calculation
-        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+        max_tokens = gen_params.pop('max_new_tokens')
         if max_tokens <= 0:
             return '', ''
@@ -772,9 +768,10 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
         if 'repetition_penalty' in gen_params:
             gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty')
 
-        # Model-specific processing
         data = {}
-        if model_type.lower().startswith('gpt') or model_type.lower().startswith('qwen'):
+        if any(x in model_type.lower() for x in ['o1', 'o3', 'o4']):
+            data = {'model': model_type, 'messages': messages, 'n': 1}
+        else:
             if 'top_k' in gen_params:
                 warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.', DeprecationWarning)
                 gen_params.pop('top_k')
@@ -784,14 +781,6 @@ def generate_request_data(self, model_type, messages, gen_params, json_mode=Fals
             gen_params.pop('skip_special_tokens', None)
             gen_params.pop('session_id', None)
             data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
             if json_mode:
                 data['response_format'] = {'type': 'json_object'}
-        elif model_type.lower().startswith('internlm'):
-            data = {'model': model_type, 'messages': messages, 'n': 1, **gen_params}
-            if json_mode:
-                data['response_format'] = {'type': 'json_object'}
-        elif model_type.lower().startswith('o1'):
-            data = {'model': model_type, 'messages': messages, 'n': 1}
-        else:
-            raise NotImplementedError(f'Model type {model_type} is not supported')
 
         return header, data