Skip to content

Commit dde2384

Browse files
daniel-jones-dev authored and JanEbbing committed
feat: add model_type parameter to translate_text()
1 parent 7e1f8cd commit dde2384

File tree

7 files changed

+91
-4
lines changed

7 files changed

+91
-4
lines changed

README.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ There are additional optional arguments to control translation, see
9999
corresponding to your input text(s). `TextResult` has the following properties:
100100
- `text` is the translated text,
101101
- `detected_source_lang` is the detected source language code,
102-
- `billed_characters` is the number of characters billed for the translation.
102+
- `billed_characters` is the number of characters billed for the translation.
103+
- `model_type_used` indicates the translation model used, but is `None` unless
104+
the `model_type` option is specified.
103105

104106
```python
105107
# Translate text into a target language, in this case, French:
@@ -162,6 +164,14 @@ arguments are:
162164
translated itself. Characters in the `context` parameter are not counted toward billing.
163165
See the [API documentation][api-docs-context-param] for more information and
164166
example usage.
167+
- `model_type`: specifies the type of translation model to use, options are:
168+
- `'quality_optimized'` (`ModelType.QUALITY_OPTIMIZED`): use a translation
169+
model that maximizes translation quality, at the cost of response time.
170+
This option may be unavailable for some language pairs.
171+
- `'prefer_quality_optimized'` (`ModelType.PREFER_QUALITY_OPTIMIZED`): use
172+
the highest-quality translation model for the given language pair.
173+
- `'latency_optimized'` (`ModelType.LATENCY_OPTIMIZED`): use a translation
174+
model that minimizes response time, at the cost of translation quality.
165175
- `tag_handling`: type of tags to parse before translation, options are `'html'`
166176
and `'xml'`.
167177

deepl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Formality,
2626
GlossaryInfo,
2727
Language,
28+
ModelType,
2829
SplitSentences,
2930
TextResult,
3031
Translator,

deepl/__main__.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def action_text(
7979
translator: deepl.Translator,
8080
show_detected_source: bool = False,
8181
show_billed_characters: Optional[bool] = None,
82+
show_model_type_used: Optional[bool] = None,
8283
**kwargs,
8384
):
8485
"""Action function for the text command."""
@@ -93,9 +94,17 @@ def action_text(
9394
text_value = (
9495
"unknown"
9596
if output.billed_characters is None
96-
else output.billed_characters
97+
else str(output.billed_characters)
9798
)
9899
print(f"Billed characters: {text_value}")
100+
if show_model_type_used:
101+
text_value = (
102+
"unknown"
103+
if output.model_type_used is None
104+
else output.model_type_used
105+
)
106+
print(f"Model type used: {text_value}")
107+
99108
print(output.text)
100109

101110

@@ -325,6 +334,19 @@ def add_common_arguments(subparser: argparse.ArgumentParser):
325334
action="store_true",
326335
help="print billed characters for each text",
327336
)
337+
parser_text.add_argument(
338+
"--show-model-type-used",
339+
dest="show_model_type_used",
340+
action="store_true",
341+
help="print the model type used for each text",
342+
)
343+
parser_text.add_argument(
344+
"--model-type",
345+
type=str,
346+
choices=[enum.value for enum in deepl.ModelType],
347+
default=None,
348+
help="control model used for translation, see API for information",
349+
)
328350
parser_text.add_argument(
329351
"text",
330352
nargs="+",

deepl/api_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ def __init__(
1717
text: str,
1818
detected_source_lang: str,
1919
billed_characters: int,
20+
model_type_used: Optional[str] = None,
2021
):
2122
self.text = text
2223
self.detected_source_lang = detected_source_lang
2324
self.billed_characters = billed_characters
25+
self.model_type_used = model_type_used
2426

2527
def __str__(self):
2628
return self.text
@@ -421,3 +423,18 @@ class SplitSentences(Enum):
421423

422424
def __str__(self):
423425
return self.value
426+
427+
428+
class ModelType(Enum):
429+
"""Options for model_type parameter.
430+
431+
Sets whether the translation engine should use a newer model type that
432+
offers higher quality translations at the cost of translation time.
433+
"""
434+
435+
QUALITY_OPTIMIZED = "quality_optimized"
436+
LATENCY_OPTIMIZED = "latency_optimized"
437+
PREFER_QUALITY_OPTIMIZED = "prefer_quality_optimized"
438+
439+
def __str__(self):
440+
return self.value

deepl/translator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Formality,
99
GlossaryInfo,
1010
GlossaryLanguagePair,
11+
ModelType,
1112
Language,
1213
SplitSentences,
1314
TextResult,
@@ -347,6 +348,7 @@ def translate_text(
347348
non_splitting_tags: Union[str, List[str], None] = None,
348349
splitting_tags: Union[str, List[str], None] = None,
349350
ignore_tags: Union[str, List[str], None] = None,
351+
model_type: Union[str, ModelType, None] = None,
350352
) -> Union[TextResult, List[TextResult]]:
351353
"""Translate text(s) into the target language.
352354
@@ -387,6 +389,8 @@ def translate_text(
387389
:param ignore_tags: (Optional) XML tags containing text that should not
388390
be translated.
389391
:type ignore_tags: List of XML tags or comma-separated-list of tags.
392+
:param model_type: (Optional) Controls whether the translation engine
393+
should use a potentially slower model to achieve higher quality.
390394
:return: List of TextResult objects containing results, unless input
391395
text was one string, then a single TextResult object is returned.
392396
"""
@@ -425,6 +429,8 @@ def translate_text(
425429
request_data["tag_handling"] = tag_handling
426430
if outline_detection is not None:
427431
request_data["outline_detection"] = bool(outline_detection)
432+
if model_type is not None:
433+
request_data["model_type"] = str(model_type)
428434

429435
def join_tags(tag_argument: Union[str, Iterable[str]]) -> List[str]:
430436
if isinstance(tag_argument, str):
@@ -462,8 +468,14 @@ def join_tags(tag_argument: Union[str, Iterable[str]]) -> List[str]:
462468
else ""
463469
)
464470
billed_characters = int(translation.get("billed_characters"))
471+
model_type_used = translation.get("model_type_used")
465472
output.append(
466-
TextResult(text, detected_source_language, billed_characters)
473+
TextResult(
474+
text,
475+
detected_source_language,
476+
billed_characters,
477+
model_type_used,
478+
)
467479
)
468480

469481
return output if multi_input else output[0]

tests/test_cli.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,14 @@ def test_languages(runner):
129129

130130
def test_text(runner):
131131
result = runner.invoke(
132-
main_function, 'text --to DE "proton beam" --show-detected-source'
132+
main_function,
133+
'text --to DE "proton beam" --show-detected-source '
134+
"--show-model-type-used --model-type quality_optimized",
133135
)
134136
assert result.exit_code == 0, f"exit: {result.exit_code}\n {result.output}"
135137
assert example_text["DE"] in result.output
136138
assert "Detected source" in result.output
139+
assert "Model type used: quality_optimized" in result.output
137140

138141
# Test text options
139142
extra_options = [
@@ -155,6 +158,10 @@ def test_text(runner):
155158
"--non-splitting-tags a,b --non-splitting-tags c",
156159
"'non_splitting_tags': ['a', 'b', 'c']",
157160
),
161+
(
162+
"--model-type quality_optimized",
163+
"'model_type': 'quality_optimized'",
164+
),
158165
]
159166
for args, search_str in extra_options:
160167
result = runner.invoke(

tests/test_translate_text.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@ def test_single_text(translator):
1717
assert result.billed_characters == len(example_text["EN"])
1818

1919

20+
@pytest.mark.parametrize(
21+
"model_type",
22+
[model_type for model_type in deepl.ModelType],
23+
)
24+
def test_model_type(translator, model_type):
25+
result = translator.translate_text(
26+
example_text["EN"], target_lang="DE", model_type=model_type
27+
)
28+
# TODO: use `removeprefix()` when we only support py3.9+
29+
expected_model_type = str(model_type)
30+
prefix_to_remove = "prefer_"
31+
if expected_model_type.startswith(prefix_to_remove):
32+
expected_model_type = expected_model_type[
33+
len(prefix_to_remove) : # noqa: E203
34+
]
35+
assert expected_model_type == result.model_type_used
36+
37+
2038
def test_string_list(translator):
2139
texts = [example_text["FR"], example_text["EN"]]
2240
result = translator.translate_text(texts, target_lang="DE")

0 commit comments

Comments
 (0)