Skip to content

Commit c8fd213

Browse files
committed
feat: implement ChatWithTrait for middleware support
- Add ChatWithTrait struct that uses HttpClient trait - Implement chat() method on ClientWithTrait - Enable HTTP middleware to intercept all OpenAI API calls - Support automatic tracing at HTTP level
1 parent d3cd3cc commit c8fd213

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// Chat API implementation that works with ClientWithTrait and HttpClient trait
2+
3+
use crate::{
4+
config::Config,
5+
error::{ApiError, OpenAIError},
6+
http_client::{HttpClient, HttpResponse},
7+
types::{
8+
CreateChatCompletionRequest, CreateChatCompletionResponse,
9+
CreateChatCompletionStreamResponse,
10+
},
11+
ClientWithTrait,
12+
};
13+
use std::pin::Pin;
14+
use futures::Stream;
15+
use bytes::Bytes;
16+
use reqwest::{Method, header::{HeaderMap, HeaderValue, CONTENT_TYPE}};
17+
18+
/// Chat API group with HttpClient trait support
19+
pub struct ChatWithTrait<'c, C: Config> {
20+
client: &'c ClientWithTrait<C>,
21+
}
22+
23+
impl<'c, C: Config> ChatWithTrait<'c, C> {
24+
pub fn new(client: &'c ClientWithTrait<C>) -> Self {
25+
Self { client }
26+
}
27+
28+
/// Creates a model response for the given chat conversation.
29+
pub async fn create(
30+
&self,
31+
request: CreateChatCompletionRequest,
32+
) -> Result<CreateChatCompletionResponse, OpenAIError> {
33+
// Prepare the request
34+
let url = self.client.config.url("/chat/completions");
35+
36+
// Serialize request body
37+
let body = serde_json::to_vec(&request)
38+
.map_err(|e| OpenAIError::JSONSerialize(e))?;
39+
40+
// Prepare headers
41+
let mut headers = self.client.config.headers();
42+
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
43+
44+
// Execute request with backoff
45+
let response = self.execute_with_backoff(
46+
Method::POST,
47+
url,
48+
headers,
49+
Some(Bytes::from(body)),
50+
).await?;
51+
52+
// Parse response
53+
serde_json::from_slice(&response.body)
54+
.map_err(|e| OpenAIError::JSONDeserialize(e))
55+
}
56+
57+
/// Creates a streaming response for the given chat conversation.
58+
pub async fn create_stream(
59+
&self,
60+
mut request: CreateChatCompletionRequest,
61+
) -> Result<
62+
Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>,
63+
OpenAIError,
64+
> {
65+
// Set stream flag
66+
request.stream = Some(true);
67+
68+
// For now, return an error as streaming requires more complex implementation
69+
// This would need to handle SSE (Server-Sent Events) parsing
70+
Err(OpenAIError::InvalidArgument(
71+
"Streaming not yet implemented for ChatWithTrait".to_string()
72+
))
73+
}
74+
75+
/// Execute request with exponential backoff for rate limiting
76+
async fn execute_with_backoff(
77+
&self,
78+
method: Method,
79+
url: reqwest::Url,
80+
headers: HeaderMap,
81+
body: Option<Bytes>,
82+
) -> Result<HttpResponse, OpenAIError> {
83+
use backoff::{future::retry, ExponentialBackoff};
84+
85+
let http_client = self.client.http_client.clone();
86+
let backoff = self.client.backoff.clone();
87+
88+
retry(backoff, || async {
89+
let result = http_client
90+
.request(method.clone(), url.clone(), headers.clone(), body.clone())
91+
.await;
92+
93+
match result {
94+
Ok(response) => {
95+
if response.status.is_success() {
96+
Ok(response)
97+
} else if response.status.as_u16() == 429 {
98+
// Rate limited, retry with backoff
99+
Err(backoff::Error::transient(OpenAIError::ApiError(
100+
ApiError {
101+
message: "Rate limited".to_string(),
102+
r#type: Some("rate_limit_exceeded".to_string()),
103+
param: None,
104+
code: None,
105+
}
106+
)))
107+
} else {
108+
// Other error, don't retry
109+
let api_error = serde_json::from_slice(&response.body)
110+
.unwrap_or_else(|_| ApiError {
111+
message: format!("HTTP {}", response.status),
112+
r#type: None,
113+
param: None,
114+
code: None,
115+
});
116+
Err(backoff::Error::permanent(OpenAIError::ApiError(api_error)))
117+
}
118+
}
119+
Err(e) => {
120+
// Network error, retry
121+
Err(backoff::Error::transient(OpenAIError::Reqwest(
122+
reqwest::Error::from(std::io::Error::new(
123+
std::io::ErrorKind::Other,
124+
e.message
125+
))
126+
)))
127+
}
128+
}
129+
})
130+
.await
131+
}
132+
}

async-openai/src/client.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,4 +589,9 @@ impl<C: Config> ClientWithTrait<C> {
589589
pub fn config(&self) -> &C {
590590
&self.config
591591
}
592+
593+
/// To call [ChatWithTrait] group related APIs using this client.
594+
pub fn chat(&self) -> crate::chat_with_trait::ChatWithTrait<C> {
595+
crate::chat_with_trait::ChatWithTrait::new(self)
596+
}
592597
}

async-openai/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ mod audio;
145145
mod audit_logs;
146146
mod batches;
147147
mod chat;
148+
mod chat_with_trait;
148149
mod client;
149150
mod completion;
150151
pub mod config;

0 commit comments

Comments
 (0)