Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 0b5b9aa

Browse files
authored
fix: sort messages by their ULID instead of created_at (#1778)
1 parent 9694ec8 commit 0b5b9aa

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

engine/repositories/message_fs_repository.cc

Lines changed: 42 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,14 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
4848
const std::string& before,
4949
const std::string& run_id) const {
5050
CTL_INF("Listing messages for thread " + thread_id);
51-
auto path = GetMessagePath(thread_id);
51+
52+
// Early validation
53+
if (limit == 0) {
54+
return std::vector<OpenAi::Message>();
55+
}
56+
if (!after.empty() && !before.empty() && after >= before) {
57+
return cpp::fail("Invalid range: 'after' must be less than 'before'");
58+
}
5259

5360
auto mutex = GrabMutex(thread_id);
5461
std::shared_lock<std::shared_mutex> lock(*mutex);
@@ -60,6 +67,11 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
6067

6168
std::vector<OpenAi::Message> messages = std::move(read_result.value());
6269

70+
if (messages.empty()) {
71+
return messages;
72+
}
73+
74+
// Filter by run_id
6375
if (!run_id.empty()) {
6476
messages.erase(std::remove_if(messages.begin(), messages.end(),
6577
[&run_id](const OpenAi::Message& msg) {
@@ -68,52 +80,52 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
6880
messages.end());
6981
}
7082

71-
std::sort(messages.begin(), messages.end(),
72-
[&order](const OpenAi::Message& a, const OpenAi::Message& b) {
73-
if (order == "desc") {
74-
return a.created_at > b.created_at;
75-
}
76-
return a.created_at < b.created_at;
77-
});
83+
const bool is_descending = (order == "desc");
84+
std::sort(
85+
messages.begin(), messages.end(),
86+
[is_descending](const OpenAi::Message& a, const OpenAi::Message& b) {
87+
return is_descending ? (a.id > b.id) : (a.id < b.id);
88+
});
7889

7990
auto start_it = messages.begin();
8091
auto end_it = messages.end();
8192

8293
if (!after.empty()) {
83-
start_it = std::find_if(
84-
messages.begin(), messages.end(),
85-
[&after](const OpenAi::Message& msg) { return msg.id == after; });
86-
if (start_it != messages.end()) {
87-
++start_it; // Start from the message after the 'after' message
88-
} else {
89-
start_it = messages.begin();
94+
start_it = std::lower_bound(
95+
messages.begin(), messages.end(), after,
96+
[is_descending](const OpenAi::Message& msg, const std::string& value) {
97+
return is_descending ? (msg.id > value) : (msg.id < value);
98+
});
99+
100+
if (start_it != messages.end() && start_it->id == after) {
101+
++start_it;
90102
}
91103
}
92104

93105
if (!before.empty()) {
94-
end_it = std::find_if(
95-
messages.begin(), messages.end(),
96-
[&before](const OpenAi::Message& msg) { return msg.id == before; });
106+
end_it = std::upper_bound(
107+
start_it, messages.end(), before,
108+
[is_descending](const std::string& value, const OpenAi::Message& msg) {
109+
return is_descending ? (value > msg.id) : (value < msg.id);
110+
});
97111
}
98112

99-
std::vector<OpenAi::Message> result;
100-
size_t distance = std::distance(start_it, end_it);
101-
size_t limit_size = static_cast<size_t>(limit);
102-
CTL_INF("Distance: " + std::to_string(distance) +
103-
", limit_size: " + std::to_string(limit_size));
104-
result.reserve(distance < limit_size ? distance : limit_size);
113+
const size_t available_messages = std::distance(start_it, end_it);
114+
const size_t result_size =
115+
std::min(static_cast<size_t>(limit), available_messages);
105116

106-
for (auto it = start_it; it != end_it && result.size() < limit_size; ++it) {
107-
result.push_back(std::move(*it));
108-
}
117+
CTL_INF("Available messages: " + std::to_string(available_messages) +
118+
", result size: " + std::to_string(result_size));
119+
120+
std::vector<OpenAi::Message> result;
121+
result.reserve(result_size);
122+
std::move(start_it, start_it + result_size, std::back_inserter(result));
109123

110124
return result;
111125
}
112126

113127
cpp::result<OpenAi::Message, std::string> MessageFsRepository::RetrieveMessage(
114128
const std::string& thread_id, const std::string& message_id) const {
115-
auto path = GetMessagePath(thread_id);
116-
117129
auto mutex = GrabMutex(thread_id);
118130
std::unique_lock<std::shared_mutex> lock(*mutex);
119131

@@ -133,8 +145,6 @@ cpp::result<OpenAi::Message, std::string> MessageFsRepository::RetrieveMessage(
133145

134146
cpp::result<void, std::string> MessageFsRepository::ModifyMessage(
135147
OpenAi::Message& message) {
136-
auto path = GetMessagePath(message.thread_id);
137-
138148
auto mutex = GrabMutex(message.thread_id);
139149
std::unique_lock<std::shared_mutex> lock(*mutex);
140150

@@ -143,6 +153,7 @@ cpp::result<void, std::string> MessageFsRepository::ModifyMessage(
143153
return cpp::fail(messages.error());
144154
}
145155

156+
auto path = GetMessagePath(message.thread_id);
146157
std::ofstream file(path, std::ios::trunc);
147158
if (!file) {
148159
return cpp::fail("Failed to open file for writing: " + path.string());

0 commit comments

Comments (0)