From f8f198aefde7a7345a34edfc772953280b867c0a Mon Sep 17 00:00:00 2001
From: Eugene Ostroukhov
Date: Sun, 20 Apr 2025 16:52:53 -0700
Subject: [PATCH] Refactor to do event queue and have a chat history

---
 src/BUILD.bazel         |  26 +++++++++++
 src/anthropic.cc        |  66 ++++++++++++--------------
 src/chat.cc             |  42 +++++++++++++++++
 src/chat.h              |  90 +++++++++++++++++++++++++++++++++++
 src/event_loop.cc       |  41 ++++++++++++++++
 src/event_loop.h        |  46 ++++++++++++++++++
 src/main.cc             |  63 ++++++++++++++++++++-----
 src/model.cc            |  27 +++++++++++
 src/model.h             |  30 +++++++++---
 src/openai.cc           |  50 +++++++++----------
 test/BUILD.bazel        |  25 ++++++++++
 test/chat.test.cc       | 101 ++++++++++++++++++++++++++++++++++++++++
 test/event_loop.test.cc |  39 ++++++++++++++++
 13 files changed, 565 insertions(+), 81 deletions(-)
 create mode 100644 src/chat.cc
 create mode 100644 src/chat.h
 create mode 100644 src/event_loop.cc
 create mode 100644 src/event_loop.h
 create mode 100644 src/model.cc
 create mode 100644 test/chat.test.cc
 create mode 100644 test/event_loop.test.cc

diff --git a/src/BUILD.bazel b/src/BUILD.bazel
index be71e38..02c609d 100644
--- a/src/BUILD.bazel
+++ b/src/BUILD.bazel
@@ -7,6 +7,7 @@ cc_binary(
     srcs = ["main.cc"],
     visibility = ["//visibility:public"],
     deps = [
+        ":chat",
         ":fetch",
         ":llms",
         ":tui",
@@ -21,6 +22,17 @@ cc_binary(
     ],
 )
 
+cc_library(
+    name = "chat",
+    srcs = ["chat.cc"],
+    hdrs = ["chat.h"],
+    deps = [
+        ":event_loop",
+        "@abseil-cpp//absl/functional:any_invocable",
+        "@abseil-cpp//absl/synchronization",
+    ],
+)
+
 cc_library(
     name = "fetch",
     srcs = ["fetch.cc"],
@@ -36,6 +48,17 @@ cc_library(
     ],
 )
 
+cc_library(name="event_loop",
+    srcs = ["event_loop.cc"],
+    hdrs = ["event_loop.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@abseil-cpp//absl/functional:any_invocable",
+        "@abseil-cpp//absl/synchronization",
+        "@abseil-cpp//absl/log",
+    ],
+)
+
 cc_library(
     name = "json_decode",
     srcs = ["json_decode.cc"],
@@ -52,6 +75,7 @@ cc_library(
     name = "llms",
     srcs = [
         "anthropic.cc",
+        "model.cc",
         "openai.cc",
     ],
     hdrs = [
@@ -61,9 +85,11 @@ cc_library(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":chat",
         ":fetch",
         ":json_decode",
         "@abseil-cpp//absl/flags:flag",
+        "@abseil-cpp//absl/log:check",
         "@abseil-cpp//absl/status:statusor",
         "@abseil-cpp//absl/strings",
         "@nlohmann_json//:json",
diff --git a/src/anthropic.cc b/src/anthropic.cc
index 0cfc29f..39be2f7 100644
--- a/src/anthropic.cc
+++ b/src/anthropic.cc
@@ -5,15 +5,18 @@
 #include
 #include
 #include
+#include
+#include
 #include
 
 #include "absl/flags/flag.h"
+#include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
 #include "nlohmann/json.hpp"
 
+#include "src/chat.h"
 #include "src/fetch.h"
 #include "src/json_decode.h"
 #include "src/model.h"
@@ -27,47 +30,39 @@ namespace {
 
 class AnthropicModel : public Model {
  public:
-  AnthropicModel(std::string_view model,
-                 std::string_view api_key, int max_tokens)
-      : model_(model),
+  AnthropicModel(std::string model, std::string_view api_key, int max_tokens,
+                 std::shared_ptr<Fetch> fetch)
+      : Model(std::move(model)),
         api_key_(api_key),
-        max_tokens_(max_tokens) {}
+        max_tokens_(max_tokens),
+        fetch_(std::move(fetch)) {}
 
   ~AnthropicModel() override = default;
 
-  std::string_view name() const override { return model_; }
-
-  absl::StatusOr<std::string> Prompt(
-      const Fetch& fetch, std::string_view prompt,
-      absl::Span<const std::string> input_contents) override;
-
  private:
-  std::string model_;
+  absl::StatusOr<std::string> Send(const Message& message) override;
+
   std::string api_key_;
   int max_tokens_;
+  std::shared_ptr<Fetch> fetch_;
+  std::unordered_map<Chat*, std::unique_ptr<Chat::Unsubscribe>> subscriptions_;
 };
 
-absl::StatusOr<std::string> AnthropicModel::Prompt(
-    const Fetch& fetch, std::string_view prompt,
-    absl::Span<const std::string> input_contents) {
-  std::string combined_input = absl::StrJoin(input_contents, "\n\n");
-
+absl::StatusOr<std::string> AnthropicModel::Send(const Message& message) {
   nlohmann::json request = {
-      {"model", model_},
+      {"model", name()},
       {"max_tokens", max_tokens_},
-      {"messages",
-       nlohmann::json::array(
-           {{{"role", "user"},
-             {"content", absl::StrCat(prompt, "\n\n", combined_input)}}})},
+      {"messages", nlohmann::json::array(
+                       {{{"role", "user"}, {"content", message.content()}}})},
   };
 
   auto response =
-      fetch.Post("https://api.anthropic.com/v1/messages",
-                 {
-                     {.key = "Content-Type", .value = "application/json"},
-                     {.key = "x-api-key", .value = api_key_},
-                     {.key = "anthropic-version", .value = "2023-06-01"},
-                 },
-                 request);
+      fetch_->Post("https://api.anthropic.com/v1/messages",
+                   {
+                       {.key = "Content-Type", .value = "application/json"},
+                       {.key = "x-api-key", .value = api_key_},
+                       {.key = "anthropic-version", .value = "2023-06-01"},
+                   },
+                   request);
   if (!response.ok()) {
     return std::move(response).status();
@@ -83,13 +78,12 @@ absl::StatusOr<std::string> AnthropicModel::Prompt(
                      (*json_response)["error"].dump(2)));
   }
 
-  auto message =
-      json::JsonDecode(*json_response)["content"][0]["text"].String();
-  if (!message.ok()) {
+  auto res = json::JsonDecode(*json_response)["content"][0]["text"].String();
+  if (!res.ok()) {
     return absl::InternalError(
-        absl::StrCat("Anthropic API error: ", message.error()));
+        absl::StrCat("Anthropic API error: ", res.error()));
   }
-  return message.value();
+  return res.value();
 }
 
 class AnthropicModelProvider : public ModelProvider {
@@ -106,8 +100,8 @@
     if (!api_key.has_value()) {
       return absl::InvalidArgumentError("Anthropic API key is required");
     }
-    auto client = std::make_unique<AnthropicModel>(model, *api_key,
-                                                   parameters_.max_tokens());
+    auto client = std::make_unique<AnthropicModel>(
+        std::string(model), *api_key, parameters_.max_tokens(), fetch_);
     return ModelHandle(std::move(client));
   }
 
diff --git a/src/chat.cc b/src/chat.cc
new file mode 100644
index 0000000..7069c9d
--- /dev/null
+++ b/src/chat.cc
@@ -0,0 +1,42 @@
+#include "src/chat.h"
+
+#include
+
+namespace uchen::chat {
+
+std::optional<Message> Chat::FindMessage(int id) const {
+  absl::MutexLock lock(&message_mutex_);
+  auto it = messages_.find(id);
+  if (it != messages_.end()) {
+    return it->second;
+  }
+  return std::nullopt;
+}
+
+Message Chat::SendMessage(Message::Origin origin, std::string content,
+                          std::optional<int> parent_id, void* provider) {
+  absl::MutexLock lock(&message_mutex_);
+  int id = next_id_++;
+  auto result = messages_.emplace(
+      id, Message(id, origin, std::move(content), parent_id, provider));
+  event_loop_->Run(
+      [chat = shared_from_this(), message = result.first->second]() {
+        absl::MutexLock lock(&chat->callback_mutex_);
+        for (const auto& [_, callback] : chat->callbacks_) {
+          callback(message);
+        }
+      });
+  return result.first->second;
+}
+
+std::unique_ptr<Chat::Unsubscribe> Chat::Subscribe(Callback callback) {
+  size_t key = next_id_++;
+  event_loop_->Run([chat = shared_from_this(), key,
+                    callback = std::move(callback)]() mutable {
+    absl::MutexLock lock(&chat->callback_mutex_);
+    chat->callbacks_.emplace(key, std::move(callback));
+  });
+  return std::make_unique<Unsubscribe>(this, key);
+}
+
+}  // namespace uchen::chat
\ No newline at end of file
diff --git a/src/chat.h b/src/chat.h
new file mode 100644
index 0000000..1d42adb
--- /dev/null
+++ b/src/chat.h
@@ -0,0 +1,90 @@
+#ifndef SRC_CHAT_H_
+#define SRC_CHAT_H_
+
+#include
+#include
+#include
+#include
+
+#include "absl/functional/any_invocable.h"
+#include "absl/synchronization/mutex.h"
+
+#include "src/event_loop.h"
+
+namespace uchen::chat {
+
+class Message {
+ public:
+  enum class Origin { kAssistant, kSystem, kUser };
+
+  Message() = default;
+  Message(int id, Origin origin, std::string content,
+          std::optional<int> parent_id, void* provider)
+      : id_(id),
+        origin_(origin),
+        content_(std::move(content)),
+        parent_id_(parent_id),
+        provider_(provider) {}
+
+  Message(const Message&) = default;
+  Message& operator=(const Message&) = default;
+  Message(Message&&) = default;
+  Message& operator=(Message&&) = default;
+
+  bool operator==(const Message& other) const = default;
+
+  int id() const { return id_; }
+  Origin origin() const { return origin_; }
+  const std::string& content() const { return content_; }
+  std::optional<int> parent_id() const { return parent_id_; }
+  void* provider() const { return provider_; }
+
+ private:
+  int id_;
+  Origin origin_;
+  std::string content_;
+  std::optional<int> parent_id_;
+  void* provider_;
+};
+
+class Chat : public std::enable_shared_from_this<Chat> {
+ public:
+  using Callback = absl::AnyInvocable<void(const Message&)>;
+
+  class Unsubscribe {
+   public:
+    Unsubscribe(Chat* chat, size_t id) : chat_(chat), id_(id) {}
+
+    ~Unsubscribe() { absl::MutexLock lock(&chat_->callback_mutex_); chat_->callbacks_.erase(id_); }
+
+   private:
+    Chat* chat_;
+    size_t id_;
+  };
+
+  static std::shared_ptr<Chat> Create(std::shared_ptr<EventLoop> event_loop) {
+    // Can't use std::make_shared because ctor is private.
+    return std::shared_ptr<Chat>(new Chat(std::move(event_loop)));
+  }
+
+  std::optional<Message> FindMessage(int id) const;
+  Message SendMessage(Message::Origin origin, std::string content,
+                      std::optional<int> parent_id, void* provider);
+  std::unique_ptr<Unsubscribe> Subscribe(Callback callback);
+
+ private:
+  explicit Chat(std::shared_ptr<EventLoop> event_loop)
+      : event_loop_(std::move(event_loop)) {}
+  mutable absl::Mutex message_mutex_;
+  mutable absl::Mutex callback_mutex_;
+  std::shared_ptr<EventLoop> event_loop_;
+  std::atomic_int next_id_{1};
+  std::unordered_map<size_t, Callback> callbacks_
+      ABSL_GUARDED_BY(&callback_mutex_);
+  std::unordered_map<int, Message> messages_
+      ABSL_GUARDED_BY(&message_mutex_);
+};
+
+}  // namespace uchen::chat
+
+#endif  // SRC_CHAT_H_
\ No newline at end of file
diff --git a/src/event_loop.cc b/src/event_loop.cc
new file mode 100644
index 0000000..8402734
--- /dev/null
+++ b/src/event_loop.cc
@@ -0,0 +1,41 @@
+#include "src/event_loop.h"
+
+#include
+#include
+
+namespace uchen::chat {
+
+void EventLoop::Run(absl::AnyInvocable<void()> task) {
+  absl::MutexLock lock(&mutex_);
+  tasks_.push_back(std::move(task));
+}
+
+void EventLoop::Loop(EventLoop* event_loop) {
+  while (true) {
+    auto done_tasks = event_loop->GetTasks();
+    if (std::holds_alternative<bool>(done_tasks)) {
+      break;
+    }
+    for (auto& task : std::get<TasksList>(done_tasks)) {
+      task();
+    }
+  }
+}
+
+std::variant<bool, std::vector<absl::AnyInvocable<void()>>>
+EventLoop::GetTasks() {
+  absl::MutexLock lock(&mutex_);
+  absl::Condition condition(
+      +[](EventLoop* event_loop)
+          ABSL_EXCLUSIVE_LOCKS_REQUIRED(event_loop->mutex_) {
+            return !event_loop->tasks_.empty() || event_loop->stop_;
+          },
+      this);
+  mutex_.Await(condition);
+  if (stop_) {
+    return true;
+  }
+  return std::move(tasks_);
+}
+
+}  // namespace uchen::chat
\ No newline at end of file
diff --git a/src/event_loop.h b/src/event_loop.h
new file mode 100644
index 0000000..8587cbb
--- /dev/null
+++ b/src/event_loop.h
@@ -0,0 +1,46 @@
+#ifndef SRC_EVENT_LOOP_H
+#define SRC_EVENT_LOOP_H
+
+#include
+#include
+#include
+
+#include "absl/functional/any_invocable.h"
+#include "absl/log/log.h"  // IWYU pragma: keep
+#include "absl/synchronization/mutex.h"
+
+namespace uchen::chat {
+
+class EventLoop {
+ public:
+  static std::shared_ptr<EventLoop> Create() {
+    return std::make_shared<EventLoop>();
+  }
+
+  EventLoop() : thread_(&EventLoop::Loop, this) {}
+
+  ~EventLoop() {
+    {
+      absl::MutexLock lock(&mutex_);
+      stop_ = true;
+    }
+    thread_.join();
+  }
+
+  void Run(absl::AnyInvocable<void()> task);
+
+ private:
+  using TasksList = std::vector<absl::AnyInvocable<void()>>;
+  static void Loop(EventLoop* event_loop);
+
+  std::variant<bool, TasksList> GetTasks();
+
+  absl::Mutex mutex_;
+  TasksList tasks_ ABSL_GUARDED_BY(&mutex_);
+  bool stop_ ABSL_GUARDED_BY(&mutex_) = false;
+  std::thread thread_;
+};
+
+}  // namespace uchen::chat
+
+#endif  // SRC_EVENT_LOOP_H
\ No newline at end of file
diff --git a/src/main.cc b/src/main.cc
index 5835625..ef489d8 100644
--- a/src/main.cc
+++ b/src/main.cc
@@ -1,4 +1,5 @@
 #include
+#include
 #include
 #include
 #include
@@ -16,6 +17,8 @@
 #include "curl/curl.h"
 
 #include "src/anthropic.h"
+#include "src/chat.h"
+#include "src/event_loop.h"
 #include "src/input.h"
 #include "src/model.h"
 #include "src/openai.h"
@@ -24,10 +27,47 @@
 namespace uchen::chat {
 namespace {
 
-int Chat(Model* model) {
+constexpr absl::Duration kTimeout = absl::Seconds(5);
+
+class MessageLog {
+ public:
+  Message WaitForMessage() {
+    absl::MutexLock lock(&mutex_);
+    mutex_.AwaitWithTimeout({+[](const std::vector<Message>* messages) {
+                               return !messages->empty();
+                             },
+                             &messages_},
+                            kTimeout);
+    Message message = std::move(messages_.back());
+    messages_.pop_back();
+    return message;
+  }
+
+  void AddMessage(Message message) {
+    if (message.provider() == this) {
+      return;
+    }
+    absl::MutexLock lock(&mutex_);
+    messages_.push_back(std::move(message));
+  }
+
+ private:
+  absl::Mutex mutex_;
+  std::vector<Message> messages_ ABSL_GUARDED_BY(&mutex_);
+};
+
+int ChatLoop(Model* model) {
   std::cout << absl::Substitute("Model: $0\nType your message below:",
                                 model->name());
-  uchen::chat::InputReader reader(std::cin);
+  auto event_loop = EventLoop::Create();
+  auto chat = Chat::Create(event_loop);
+  auto model_unsubscribe = model->Connect(chat);
+  InputReader reader(std::cin);
+  CurlFetch fetch;
+  MessageLog message_log;
+  auto subscription = chat->Subscribe(
+      [&](Message message) { message_log.AddMessage(std::move(message)); });
+
   while (true) {
     std::cout << "\n> ";
     auto prompt = reader();
@@ -35,14 +75,14 @@
       return 0;
     }
     if (!prompt->empty()) {
-      uchen::chat::CurlFetch fetch;
-      auto response =
-          SpinWhile([&]() { return model->Prompt(fetch, *prompt, {}); });
-      if (!response.ok()) {
-        std::cerr << "Error: " << response.status().message() << std::endl;
-        return 1;
-      }
-      std::cout << *response << std::endl;
+      chat->SendMessage(Message::Origin::kUser, *prompt, std::nullopt,
+                        &message_log);
+      auto response = SpinWhile([&]() { return message_log.WaitForMessage(); });
+      // if (!response.ok()) {
+      //   std::cerr << "Error: " << response.status().message() << std::endl;
+      //   return 1;
+      // }
+      std::cout << response.content() << std::endl;
     }
   }
 }
@@ -73,7 +113,6 @@ int main(int argc, char* argv[], char* envp[]) {
       uchen::chat::MakeOpenAIModelProvider(fetch, parameters),
       uchen::chat::MakeAnthropicModelProvider(fetch, parameters),
   };
-
   if (absl::GetFlag(FLAGS_list)) {
     for (const auto& provider : providers) {
       auto models = provider->ListModels();
@@ -102,6 +141,6 @@
     return 1;
   }
   CHECK_NE(model->get(), nullptr);
-  return uchen::chat::Chat(model->get());
+  return uchen::chat::ChatLoop(model->get());
   }
 }
\ No newline at end of file
diff --git a/src/model.cc b/src/model.cc
new file mode 100644
index 0000000..cc2a7dc
--- /dev/null
+++ b/src/model.cc
@@ -0,0 +1,27 @@
+#include "src/model.h"
+
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+
+namespace uchen::chat {
+
+// Connects the model to a chat session
+std::unique_ptr<Model::Unsubscribe> Model::Connect(std::shared_ptr<Chat> chat) {
+  CHECK_EQ(subscriptions_.count(chat.get()), 0);
+  subscriptions_.emplace(
+      chat.get(), chat->Subscribe([this, chat](const Message& message) {
+        if (message.provider() == this) {
+          return;
+        }
+        auto response = Send(message);
+        if (!response.ok()) {
+          LOG(ERROR) << "Error sending message: " << response.status();
+          return;
+        }
+        chat->SendMessage(Message::Origin::kAssistant, response.value(),
+                          message.id(), this);
+      }));
+  return std::make_unique<Unsubscribe>(this, chat.get());
+}
+
+}  // namespace uchen::chat
\ No newline at end of file
diff --git a/src/model.h b/src/model.h
index 7874c23..efe9838 100644
--- a/src/model.h
+++ b/src/model.h
@@ -2,27 +2,45 @@
 #define SRC_MODEL_H_
 
 #include
+#include
 #include
 #include
 #include
+#include
 
 #include "absl/status/statusor.h"
-#include "src/fetch.h"
+#include "src/chat.h"
 
 namespace uchen::chat {
 
 // Interface for LLM clients
 class Model {
  public:
+  class Unsubscribe {
+   public:
+    Unsubscribe(Model* model, Chat* chat) : model_(model), chat_(chat) {}
+
+    ~Unsubscribe() { model_->subscriptions_.erase(chat_); }
+
+   private:
+    Model* model_;
+    Chat* chat_;
+  };
+
+  explicit Model(std::string name) : name_(std::move(name)) {}
   virtual ~Model() = default;
 
-  virtual std::string_view name() const = 0;
+  std::string_view name() const { return name_; }
+
+  // Connects the model to a chat session
+  std::unique_ptr<Unsubscribe> Connect(std::shared_ptr<Chat> chat);
+
+ protected:
+  virtual absl::StatusOr<std::string> Send(const Message& message) = 0;
 
-  // Queries the LLM with a prompt and multiple input contents
-  virtual absl::StatusOr<std::string> Prompt(
-      const Fetch& fetch, std::string_view prompt,
-      absl::Span<const std::string> input_contents) = 0;
+  std::string name_;
+  std::unordered_map<Chat*, std::unique_ptr<Chat::Unsubscribe>> subscriptions_;
 };
 
 class Parameters {
diff --git a/src/openai.cc b/src/openai.cc
index 86ef598..cd1d6b0 100644
--- a/src/openai.cc
+++ b/src/openai.cc
@@ -6,16 +6,18 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "absl/flags/flag.h"
+#include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
-#include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
 #include "nlohmann/json.hpp"
 
+#include "src/chat.h"
 #include "src/fetch.h"
 #include "src/json_decode.h"
 #include "src/model.h"
@@ -29,40 +31,35 @@ namespace {
 
 class OpenAIModel : public Model {
  public:
-  explicit OpenAIModel(std::string_view model, std::string_view api_key,
-                       int max_tokens)
-      : model_(model), api_key_(api_key), max_tokens_(max_tokens) {}
+  explicit OpenAIModel(std::string model, std::string_view api_key,
+                       int max_tokens, std::shared_ptr<Fetch> fetch)
+      : Model(std::move(model)),
+        api_key_(api_key),
+        max_tokens_(max_tokens),
+        fetch_(std::move(fetch)) {}
 
   ~OpenAIModel() override = default;
 
-  std::string_view name() const override { return model_; }
-
-  absl::StatusOr<std::string> Prompt(
-      const Fetch& fetch, std::string_view prompt,
-      absl::Span<const std::string> input_contents) override;
-
  private:
-  std::string model_;
+  absl::StatusOr<std::string> Send(const Message& message) override;
+
   std::string api_key_;
   int max_tokens_;
+  std::shared_ptr<Fetch> fetch_;
+  std::unordered_map<Chat*, std::unique_ptr<Chat::Unsubscribe>> subscriptions_;
 };
 
-absl::StatusOr<std::string> OpenAIModel::Prompt(
-    const Fetch& fetch, std::string_view prompt,
-    absl::Span<const std::string> input_contents) {
-  std::string combined_input = absl::StrJoin(input_contents, "\n\n");
-
-  auto response = fetch.Post(
+absl::StatusOr<std::string> OpenAIModel::Send(const Message& message) {
+  auto response = fetch_->Post(
       "https://api.openai.com/v1/chat/completions",
       {
           {.key = "Content-Type", .value = "application/json"},
          {.key = "Authorization", .value = absl::StrCat("Bearer ", api_key_)},
       },
-      {{"model", model_},
+      {{"model", name()},
       {"max_tokens", max_tokens_},
       {"messages", nlohmann::json::array(
-                        {{{"role", "user"},
-                          {"content", absl::StrCat(prompt, "\n\n", combined_input)}}})}});
+                        {{{"role", "user"}, {"content", message.content()}}})}});
 
   if (!response.ok()) {
     return std::move(response).status();
@@ -82,14 +79,13 @@ absl::StatusOr<std::string> OpenAIModel::Prompt(
         absl::StrCat("OpenAI API error: ", error_message));
   }
 
-  auto message =
+  auto res =
       json::JsonDecode(*json_response)["choices"][0]["message"]["content"]
           .String();
-  if (!message.ok()) {
-    return absl::InternalError(
-        absl::StrCat("OpenAI API error: ", message.error()));
+  if (!res.ok()) {
+    return absl::InternalError(absl::StrCat("OpenAI API error: ", res.error()));
   }
-  return message.value();
+  return res.value();
 }
 
 class OpenAIModelProvider : public ModelProvider {
@@ -110,8 +106,8 @@
     if (!api_key.has_value()) {
       return absl::InvalidArgumentError("API key is required");
     }
-    auto client = std::make_unique<OpenAIModel>(model, *api_key,
-                                                parameters_.max_tokens());
+    auto client = std::make_unique<OpenAIModel>(
+        std::string(model), *api_key, parameters_.max_tokens(), fetch_);
     return ModelHandle(std::move(client));
   }
 
diff --git a/test/BUILD.bazel b/test/BUILD.bazel
index a01f957..c12751a 100644
--- a/test/BUILD.bazel
+++ b/test/BUILD.bazel
@@ -1,5 +1,30 @@
 load("@rules_cc//cc:defs.bzl", "cc_test")
 
+cc_test(name="chat_test",
+    srcs=["chat.test.cc"],
+    deps=[
+        "//src:chat",
+        "@abseil-cpp//absl/log",
+        "@abseil-cpp//absl/log:globals",
+        "@abseil-cpp//absl/log:initialize",
+        "@googletest//:gtest",
+        "@googletest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "event_loop_test",
+    srcs = ["event_loop.test.cc"],
+    deps = [
+        "//src:event_loop",
+        "@abseil-cpp//absl/log",
+        "@abseil-cpp//absl/log:globals",
+        "@abseil-cpp//absl/log:initialize",
+        "@googletest//:gtest",
+        "@googletest//:gtest_main",
+    ],
+)
+
 cc_test(
     name = "input_test",
     srcs = ["input.test.cc"],
diff --git a/test/chat.test.cc b/test/chat.test.cc
new file mode 100644
index 0000000..9930a59
--- /dev/null
+++ b/test/chat.test.cc
@@ -0,0 +1,101 @@
+#include "src/chat.h"
+
+#include
+#include
+
+#include
+
+#include "absl/log/globals.h"
+#include "absl/log/initialize.h"
+#include "absl/synchronization/notification.h"
+
+#include "gmock/gmock.h"
+#include "src/event_loop.h"
+
+namespace uchen::chat {
+namespace {
+
+using Origin = Message::Origin;
+
+TEST(ChatTest, SubscribeUnsubscribe) {
+  auto chat = Chat::Create(EventLoop::Create());
+  using Record = std::tuple<Origin, std::string, std::optional<int>>;
+  absl::Mutex mutex;
+  absl::MutexLock lock(&mutex);
+  std::vector<Record> messages1;
+  std::vector<Record> messages2;
+
+  auto subscribe1 = chat->Subscribe([&](const Message& message) {
+    absl::MutexLock lock(&mutex);
+    messages1.emplace_back(message.origin(), message.content(),
+                           message.parent_id());
+  });
+  auto subscribe2 = chat->Subscribe([&](const Message& message) {
+    absl::MutexLock lock(&mutex);
+    messages2.emplace_back(message.origin(), message.content(),
+                           message.parent_id());
+  });
+  Message msg1 = chat->SendMessage(Origin::kUser, "1", std::nullopt, nullptr);
+  Message msg2 = chat->SendMessage(Origin::kSystem, "2", msg1.id(), nullptr);
+  mutex.AwaitWithTimeout(
+      {+[](std::vector<Record>* values) { return values->size() == 2; },
+       &messages1},
+      absl::Milliseconds(50));
+  subscribe1.reset();
+  chat->SendMessage(Origin::kSystem, "3", std::nullopt, nullptr);
+
+  mutex.AwaitWithTimeout(
+      {+[](std::vector<Record>* values) { return values->size() == 3; },
+       &messages2},
+      absl::Milliseconds(50));
+
+  EXPECT_THAT(messages1, ::testing::ElementsAre(
+                             std::tuple(Origin::kUser, "1", std::nullopt),
+                             std::tuple(Origin::kSystem, "2", msg1.id())));
+  EXPECT_THAT(messages2, ::testing::ElementsAre(
+                             std::tuple(Origin::kUser, "1", std::nullopt),
+                             std::tuple(Origin::kSystem, "2", msg1.id()),
+                             std::tuple(Origin::kSystem, "3", std::nullopt)));
+}
+
+TEST(ChatTest, RespondToMessage) {
+  auto chat = Chat::Create(EventLoop::Create());
+  absl::Notification notification;
+  int key = -1;
+  auto subscribe = chat->Subscribe(
+      [chat, &notification, key = &key](const Message& message) {
+        if (message.parent_id() == std::nullopt) {
+          *key = chat->SendMessage(Origin::kAssistant, message.content() + ".1",
+                                   message.id(), nullptr)
+                     .id();
+        } else {
+          notification.Notify();
+        }
+      });
+  Message msg1 = chat->SendMessage(Origin::kUser, "1", std::nullopt, nullptr);
+  notification.WaitForNotificationWithTimeout(absl::Milliseconds(20));
+  EXPECT_NE(key, -1);
+  EXPECT_EQ(chat->FindMessage(key)->content(), "1.1");
+}
+
+TEST(ChatTest, FindMessage) {
+  auto chat = Chat::Create(EventLoop::Create());
+  Message msg1 = chat->SendMessage(Origin::kSystem, "1", std::nullopt, nullptr);
+  Message msg2 = chat->SendMessage(Origin::kSystem, "2", msg1.id(), nullptr);
+  Message msg3 = chat->SendMessage(Origin::kUser, "3", msg2.id(), nullptr);
+
+  EXPECT_EQ(chat->FindMessage(msg1.id()), msg1);
+  EXPECT_EQ(chat->FindMessage(msg2.id()), msg2);
+  EXPECT_EQ(chat->FindMessage(999), std::nullopt);
+}
+
+}  // namespace
+}  // namespace uchen::chat
+
+int main(int argc, char** argv) {
+  absl::InitializeLog();
+  absl::SetStderrThreshold(absl::LogSeverity::kInfo);
+  absl::SetMinLogLevel(absl::LogSeverityAtLeast::kInfo);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
\ No newline at end of file
diff --git a/test/event_loop.test.cc b/test/event_loop.test.cc
new file mode 100644
index 0000000..c358a7b
--- /dev/null
+++ b/test/event_loop.test.cc
@@ -0,0 +1,39 @@
+
+#include "src/event_loop.h"
+#include
+
+#include "absl/log/globals.h"
+#include "absl/log/initialize.h"
+#include "absl/synchronization/notification.h"
+
+namespace uchen::chat {
+
+TEST(EventLoopTest, Basic) {
+  EventLoop loop;
+  absl::Notification notification;
+  loop.Run([&]() {
+    notification.Notify();
+  });
+  notification.WaitForNotificationWithTimeout(absl::Milliseconds(50));
+  EXPECT_TRUE(notification.HasBeenNotified());
+}
+
+TEST(EventLoopTest, TaskCanScheduleATask) {
+  EventLoop loop;
+  absl::Notification notification;
+  loop.Run([&]() {
+    loop.Run([&]() { notification.Notify(); });
+  });
+  notification.WaitForNotificationWithTimeout(absl::Milliseconds(50));
+  EXPECT_TRUE(notification.HasBeenNotified());
+}
+
+}
+
+int main(int argc, char** argv) {
+  absl::InitializeLog();
+  absl::SetStderrThreshold(absl::LogSeverity::kInfo);
+  absl::SetMinLogLevel(absl::LogSeverityAtLeast::kInfo);
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
\ No newline at end of file
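
Usage sketch (illustrative only, not part of the diff above): the snippet below shows how the new EventLoop, Chat, and Model::Connect pieces are expected to compose, assuming the declarations introduced by this patch. EchoModel is a hypothetical stand-in for a real provider such as AnthropicModel or OpenAIModel; the sleep is only a crude way to let the asynchronous dispatch run before the program exits.

    #include <chrono>
    #include <iostream>
    #include <memory>
    #include <optional>
    #include <string>
    #include <thread>

    #include "src/chat.h"
    #include "src/event_loop.h"
    #include "src/model.h"

    namespace {

    // Hypothetical model that simply echoes the user's message back.
    class EchoModel : public uchen::chat::Model {
     public:
      EchoModel() : Model("echo") {}

     protected:
      absl::StatusOr<std::string> Send(
          const uchen::chat::Message& message) override {
        return "echo: " + message.content();
      }
    };

    }  // namespace

    int main() {
      using uchen::chat::Chat;
      using uchen::chat::EventLoop;
      using uchen::chat::Message;

      // The chat owns the event loop; callbacks are dispatched on its thread.
      auto chat = Chat::Create(EventLoop::Create());
      EchoModel model;
      // Connect() subscribes the model so it answers every non-model message.
      auto connection = model.Connect(chat);
      // A plain subscriber that prints every message posted to the chat.
      auto subscription = chat->Subscribe([](const Message& m) {
        std::cout << m.content() << std::endl;
      });
      chat->SendMessage(Message::Origin::kUser, "hi", std::nullopt, nullptr);
      // Crude wait for the asynchronous reply; a real caller would block on an
      // absl::Notification or a MessageLog as main.cc does.
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      // Destroying the unsubscribe tokens before the chat breaks the
      // chat -> callback -> chat reference cycle and lets the loop shut down.
    }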