@@ -48,7 +48,14 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
4848 const std::string& before,
4949 const std::string& run_id) const {
5050 CTL_INF (" Listing messages for thread " + thread_id);
51- auto path = GetMessagePath (thread_id);
51+
52+ // Early validation
53+ if (limit == 0 ) {
54+ return std::vector<OpenAi::Message>();
55+ }
56+ if (!after.empty () && !before.empty () && after >= before) {
57+ return cpp::fail (" Invalid range: 'after' must be less than 'before'" );
58+ }
5259
5360 auto mutex = GrabMutex (thread_id);
5461 std::shared_lock<std::shared_mutex> lock (*mutex);
@@ -60,6 +67,11 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
6067
6168 std::vector<OpenAi::Message> messages = std::move (read_result.value ());
6269
70+ if (messages.empty ()) {
71+ return messages;
72+ }
73+
74+ // Filter by run_id
6375 if (!run_id.empty ()) {
6476 messages.erase (std::remove_if (messages.begin (), messages.end (),
6577 [&run_id](const OpenAi::Message& msg) {
@@ -68,52 +80,52 @@ MessageFsRepository::ListMessages(const std::string& thread_id, uint8_t limit,
6880 messages.end ());
6981 }
7082
71- std::sort (messages.begin (), messages.end (),
72- [&order](const OpenAi::Message& a, const OpenAi::Message& b) {
73- if (order == " desc" ) {
74- return a.created_at > b.created_at ;
75- }
76- return a.created_at < b.created_at ;
77- });
83+ const bool is_descending = (order == " desc" );
84+ std::sort (
85+ messages.begin (), messages.end (),
86+ [is_descending](const OpenAi::Message& a, const OpenAi::Message& b) {
87+ return is_descending ? (a.id > b.id ) : (a.id < b.id );
88+ });
7889
7990 auto start_it = messages.begin ();
8091 auto end_it = messages.end ();
8192
8293 if (!after.empty ()) {
83- start_it = std::find_if (
84- messages.begin (), messages.end (),
85- [&after](const OpenAi::Message& msg) { return msg.id == after; });
86- if (start_it != messages.end ()) {
87- ++start_it; // Start from the message after the 'after' message
88- } else {
89- start_it = messages.begin ();
94+ start_it = std::lower_bound (
95+ messages.begin (), messages.end (), after,
96+ [is_descending](const OpenAi::Message& msg, const std::string& value) {
97+ return is_descending ? (msg.id > value) : (msg.id < value);
98+ });
99+
100+ if (start_it != messages.end () && start_it->id == after) {
101+ ++start_it;
90102 }
91103 }
92104
93105 if (!before.empty ()) {
94- end_it = std::find_if (
95- messages.begin (), messages.end (),
96- [&before](const OpenAi::Message& msg) { return msg.id == before; });
106+ end_it = std::upper_bound (
107+ start_it, messages.end (), before,
108+ [is_descending](const std::string& value, const OpenAi::Message& msg) {
109+ return is_descending ? (value > msg.id ) : (value < msg.id );
110+ });
97111 }
98112
99- std::vector<OpenAi::Message> result;
100- size_t distance = std::distance (start_it, end_it);
101- size_t limit_size = static_cast <size_t >(limit);
102- CTL_INF (" Distance: " + std::to_string (distance) +
103- " , limit_size: " + std::to_string (limit_size));
104- result.reserve (distance < limit_size ? distance : limit_size);
113+ const size_t available_messages = std::distance (start_it, end_it);
114+ const size_t result_size =
115+ std::min (static_cast <size_t >(limit), available_messages);
105116
106- for (auto it = start_it; it != end_it && result.size () < limit_size; ++it) {
107- result.push_back (std::move (*it));
108- }
117+ CTL_INF (" Available messages: " + std::to_string (available_messages) +
118+ " , result size: " + std::to_string (result_size));
119+
120+ std::vector<OpenAi::Message> result;
121+ result.reserve (result_size);
122+ std::move (start_it, start_it + result_size, std::back_inserter (result));
109123
110124 return result;
111125}
112126
113127cpp::result<OpenAi::Message, std::string> MessageFsRepository::RetrieveMessage (
114128 const std::string& thread_id, const std::string& message_id) const {
115- auto path = GetMessagePath (thread_id);
116-
117129 auto mutex = GrabMutex (thread_id);
118130 std::unique_lock<std::shared_mutex> lock (*mutex);
119131
@@ -133,8 +145,6 @@ cpp::result<OpenAi::Message, std::string> MessageFsRepository::RetrieveMessage(
133145
134146cpp::result<void , std::string> MessageFsRepository::ModifyMessage (
135147 OpenAi::Message& message) {
136- auto path = GetMessagePath (message.thread_id );
137-
138148 auto mutex = GrabMutex (message.thread_id );
139149 std::unique_lock<std::shared_mutex> lock (*mutex);
140150
@@ -143,6 +153,7 @@ cpp::result<void, std::string> MessageFsRepository::ModifyMessage(
143153 return cpp::fail (messages.error ());
144154 }
145155
156+ auto path = GetMessagePath (message.thread_id );
146157 std::ofstream file (path, std::ios::trunc);
147158 if (!file) {
148159 return cpp::fail (" Failed to open file for writing: " + path.string ());
0 commit comments