diff --git a/include/http/DownloadCallback.h b/include/http/DownloadCallback.h new file mode 100644 index 00000000..7066d189 --- /dev/null +++ b/include/http/DownloadCallback.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include + +namespace OpenShock::HTTP { + using DownloadCallback = std::function; +} diff --git a/include/http/HTTPClient.h b/include/http/HTTPClient.h new file mode 100644 index 00000000..77388283 --- /dev/null +++ b/include/http/HTTPClient.h @@ -0,0 +1,53 @@ +#pragma once + +#include "Common.h" +#include "http/HTTPClientState.h" +#include "http/HTTPResponse.h" +#include "http/JsonResponse.h" +#include "RateLimiter.h" + +#include + +#include +#include + +namespace OpenShock::HTTP { + class HTTPClient { + DISABLE_COPY(HTTPClient); + DISABLE_MOVE(HTTPClient); + + public: + HTTPClient(const char* url, uint32_t timeoutMs = 10'000) + : m_state(std::make_shared(url, timeoutMs)) + { + } + + inline esp_err_t SetUrl(const char* url) { + return m_state->SetUrl(url); + } + + inline esp_err_t SetHeader(const char* key, const char* value) { + return m_state->SetHeader(key, value); + } + + inline HTTPResponse Get() { + auto response = m_state->StartRequest(HTTP_METHOD_GET, 0); + if (response.error != HTTPError::None) return HTTP::HTTPResponse(response.error, response.retryAfterSeconds); + + return HTTP::HTTPResponse(m_state, response.statusCode, response.contentLength, std::move(response.headers)); + } + template + inline JsonResponse GetJson(JsonParserFn jsonParser) { + auto response = m_state->StartRequest(HTTP_METHOD_GET, 0); + if (response.error != HTTPError::None) return HTTP::JsonResponse(response.error, response.retryAfterSeconds); + + return HTTP::JsonResponse(m_state, jsonParser, response.statusCode, response.contentLength, std::move(response.headers)); + } + + inline esp_err_t Close() { + return m_state->Close(); + } + private: + std::shared_ptr m_state; + }; +} // namespace OpenShock::HTTP diff --git a/include/http/HTTPClientState.h b/include/http/HTTPClientState.h new file mode 100644 index 00000000..f1f79c5d --- /dev/null +++ b/include/http/HTTPClientState.h @@ -0,0 +1,86 @@ +#pragma once + +#include "Common.h" +#include "http/DownloadCallback.h" +#include "http/HTTPError.h" +#include "http/JsonParserFn.h" +#include "http/ReadResult.h" + +#include + +#include + +#include +#include +#include + +namespace OpenShock::HTTP { + class HTTPClientState { + DISABLE_COPY(HTTPClientState); + DISABLE_MOVE(HTTPClientState); + public: + HTTPClientState(const char* url, uint32_t timeoutMs); + ~HTTPClientState(); + + esp_err_t SetUrl(const char* url); + + esp_err_t SetHeader(const char* key, const char* value); + + struct HeaderEntry { + std::string key; + std::string value; + }; + + struct [[nodiscard]] StartRequestResult { + HTTPError error{}; + uint32_t retryAfterSeconds{}; + uint16_t statusCode{}; + bool isChunked{}; + uint32_t contentLength{}; + std::map headers{}; + }; + + StartRequestResult StartRequest(esp_http_client_method_t method, int writeLen); + + // High-throughput streaming logic + ReadResult ReadStreamImpl(DownloadCallback cb); + + ReadResult ReadStringImpl(uint32_t reserve); + + template + inline ReadResult ReadJsonImpl(uint32_t reserve, JsonParserFn jsonParser) + { + auto response = ReadStringImpl(reserve); + if (response.error != HTTPError::None) { + return response.error; + } + + cJSON* json = cJSON_ParseWithLength(response.data.c_str(), response.data.length()); + if (json == nullptr) { + return HTTPError::ParseFailed; + } + + T data; + if (!jsonParser(json, data)) { + return HTTPError::ParseFailed; + } + + cJSON_Delete(json); + + return data; + } + + inline esp_err_t Close() { + if (m_handle == nullptr) return ESP_FAIL; + return esp_http_client_close(m_handle); + } + private: + static esp_err_t EventHandler(esp_http_client_event_t* evt); + esp_err_t EventHeaderHandler(std::string key, std::string value); + + esp_http_client_handle_t m_handle; + bool m_reading; + uint32_t m_retryAfterSeconds; + std::map m_headers; + }; +} // namespace OpenShock::HTTP diff --git a/include/http/HTTPError.h b/include/http/HTTPError.h new file mode 100644 index 00000000..37d3ef37 --- /dev/null +++ b/include/http/HTTPError.h @@ -0,0 +1,47 @@ +#pragma once + +namespace OpenShock::HTTP { + enum class HTTPError { + None, + ClientBusy, + InternalError, + RateLimited, + InvalidUrl, + InvalidHttpMethod, + NetworkError, + ConnectionClosed, + SizeLimitExceeded, + Aborted, + ParseFailed + }; + + inline const char* HTTPErrorToString(HTTPError error) { + switch (error) + { + case HTTPError::None: + return "None"; + case HTTPError::ClientBusy: + return "ClientBusy"; + case HTTPError::InternalError: + return "InternalError"; + case HTTPError::RateLimited: + return "RateLimited"; + case HTTPError::InvalidUrl: + return "InvalidUrl"; + case HTTPError::InvalidHttpMethod: + return "InvalidHttpMethod"; + case HTTPError::NetworkError: + return "NetworkError"; + case HTTPError::ConnectionClosed: + return "ConnectionClosed"; + case HTTPError::SizeLimitExceeded: + return "SizeLimitExceeded"; + case HTTPError::Aborted: + return "Aborted"; + case HTTPError::ParseFailed: + return "ParseFailed"; + default: + return "Unknown"; + } + } +} // namespace OpenShock::HTTP diff --git a/include/http/HTTPRequestManager.h b/include/http/HTTPRequestManager.h deleted file mode 100644 index 19ef5564..00000000 --- a/include/http/HTTPRequestManager.h +++ /dev/null @@ -1,88 +0,0 @@ -#pragma once - -#include - -#include - -#include -#include -#include - -#include "span.h" - -namespace OpenShock::HTTP { - enum class RequestResult : uint8_t { - InternalError, // Internal error - InvalidURL, // Invalid URL - RequestFailed, // Failed to start request - TimedOut, // Request timed out - RateLimited, // Rate limited (can be both local and global) - CodeRejected, // Request completed, but response code was not OK - ParseFailed, // Request completed, but JSON parsing failed - Cancelled, // Request was cancelled - Success, // Request completed successfully - }; - - template - struct [[nodiscard]] Response { - RequestResult result; - int code; - T data; - - inline const char* ResultToString() const - { - switch (result) { - case RequestResult::InternalError: - return "Internal error"; - case RequestResult::InvalidURL: - return "Requested url was invalid"; - case RequestResult::RequestFailed: - return "Request failed"; - case RequestResult::TimedOut: - return "Request timed out"; - case RequestResult::RateLimited: - return "Client was ratelimited"; - case RequestResult::CodeRejected: - return "Unexpected response code"; - case RequestResult::ParseFailed: - return "Parsing the response failed"; - case RequestResult::Cancelled: - return "Request was cancelled"; - case RequestResult::Success: - return "Success"; - default: - return "Unknown reason"; - } - } - }; - - template - using JsonParser = std::function; - using GotContentLengthCallback = std::function; - using DownloadCallback = std::function; - - Response Download(std::string_view url, const std::map& headers, GotContentLengthCallback contentLengthCallback, DownloadCallback downloadCallback, tcb::span acceptedCodes, uint32_t timeoutMs = 10'000); - Response GetString(std::string_view url, const std::map& headers, tcb::span acceptedCodes, uint32_t timeoutMs = 10'000); - - template - Response GetJSON(std::string_view url, const std::map& headers, JsonParser jsonParser, tcb::span acceptedCodes, uint32_t timeoutMs = 10'000) { - auto response = GetString(url, headers, acceptedCodes, timeoutMs); - if (response.result != RequestResult::Success) { - return {response.result, response.code, {}}; - } - - cJSON* json = cJSON_ParseWithLength(response.data.c_str(), response.data.length()); - if (json == nullptr) { - return {RequestResult::ParseFailed, response.code, {}}; - } - - T data; - if (!jsonParser(response.code, json, data)) { - return {RequestResult::ParseFailed, response.code, {}}; - } - - cJSON_Delete(json); - - return {response.result, response.code, std::move(data)}; - } -} // namespace OpenShock::HTTP diff --git a/include/http/HTTPResponse.h b/include/http/HTTPResponse.h new file mode 100644 index 00000000..e5cdaf26 --- /dev/null +++ b/include/http/HTTPResponse.h @@ -0,0 +1,88 @@ +#pragma once + +#include "Common.h" +#include "http/DownloadCallback.h" +#include "http/HTTPClientState.h" +#include "http/JsonParserFn.h" +#include "http/ReadResult.h" + +#include +#include +#include +#include + +namespace OpenShock::HTTP { + class HTTPClient; + class [[nodiscard]] HTTPResponse { + DISABLE_DEFAULT(HTTPResponse); + DISABLE_COPY(HTTPResponse); + DISABLE_MOVE(HTTPResponse); + + friend class HTTPClient; + + HTTPResponse(std::shared_ptr state, uint16_t statusCode, uint32_t contentLength, std::map headers) + : m_state(state) + , m_error(HTTPError::None) + , m_retryAfterSeconds(0) + , m_statusCode(statusCode) + , m_contentLength(contentLength) + , m_headers(std::move(headers)) + { + } + public: + HTTPResponse(HTTPError error) + : m_state() + , m_error(error) + , m_retryAfterSeconds() + , m_statusCode(0) + , m_contentLength(0) + , m_headers() + { + } + HTTPResponse(HTTPError error, uint32_t retryAfterSeconds) + : m_state() + , m_error(error) + , m_retryAfterSeconds(retryAfterSeconds) + , m_statusCode(0) + , m_contentLength(0) + , m_headers() + { + } + + inline bool Ok() const { return m_error == HTTPError::None && !m_state.expired(); } + inline HTTPError Error() const { return m_error; } + inline uint32_t RetryAfterSeconds() const { return m_retryAfterSeconds; } + inline uint16_t StatusCode() const { return m_statusCode; } + inline uint32_t ContentLength() const { return m_contentLength; } + + inline ReadResult ReadStream(DownloadCallback downloadCallback) { + auto locked = m_state.lock(); + if (locked == nullptr) return HTTPError::ConnectionClosed; + + return locked->ReadStreamImpl(downloadCallback); + } + + inline ReadResult ReadString() { + auto locked = m_state.lock(); + if (locked == nullptr) return HTTPError::ConnectionClosed; + + return locked->ReadStringImpl(m_contentLength); + } + + template + inline ReadResult ReadJson(JsonParserFn jsonParser) + { + auto locked = m_state.lock(); + if (locked == nullptr) return HTTPError::ConnectionClosed; + + return locked->ReadJsonImpl(m_contentLength, jsonParser); + } + private: + std::weak_ptr m_state; + HTTPError m_error; + uint32_t m_retryAfterSeconds; + uint16_t m_statusCode; + uint32_t m_contentLength; + std::map m_headers; + }; +} // namespace OpenShock::HTTP diff --git a/include/http/JsonAPI.h b/include/http/JsonAPI.h index fc735bb4..258b83ce 100644 --- a/include/http/JsonAPI.h +++ b/include/http/JsonAPI.h @@ -1,6 +1,7 @@ #pragma once -#include "http/HTTPRequestManager.h" +#include "http/HTTPClient.h" +#include "http/JsonResponse.h" #include "serialization/JsonAPI.h" #include @@ -9,15 +10,15 @@ namespace OpenShock::HTTP::JsonAPI { /// @brief Links the hub to the account with the given account link code, returns the hub token. Valid response codes: 200, 404 /// @param hubToken /// @return - HTTP::Response LinkAccount(std::string_view accountLinkCode); + JsonResponse LinkAccount(std::string_view accountLinkCode); /// @brief Gets the hub info for the given hub token. Valid response codes: 200, 401 /// @param hubToken /// @return - HTTP::Response GetHubInfo(std::string_view hubToken); + JsonResponse GetHubInfo(const char* hubToken); /// @brief Requests a Live Control Gateway to connect to. Valid response codes: 200, 401 /// @param hubToken /// @return - HTTP::Response AssignLcg(std::string_view hubToken); + JsonResponse AssignLcg(const char* hubToken); } // namespace OpenShock::HTTP::JsonAPI diff --git a/include/http/JsonParserFn.h b/include/http/JsonParserFn.h new file mode 100644 index 00000000..40e8b6c9 --- /dev/null +++ b/include/http/JsonParserFn.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +class cJSON; + +namespace OpenShock::HTTP { + template + using JsonParserFn = std::function; +} diff --git a/include/http/JsonResponse.h b/include/http/JsonResponse.h new file mode 100644 index 00000000..a25c7181 --- /dev/null +++ b/include/http/JsonResponse.h @@ -0,0 +1,78 @@ +#pragma once + +#include "Common.h" +#include "http/DownloadCallback.h" +#include "http/HTTPClientState.h" +#include "http/JsonParserFn.h" +#include "http/ReadResult.h" + +#include +#include +#include +#include + +namespace OpenShock::HTTP { + class HTTPClient; + template + class [[nodiscard]] JsonResponse { + DISABLE_DEFAULT(JsonResponse); + DISABLE_COPY(JsonResponse); + DISABLE_MOVE(JsonResponse); + + friend class HTTPClient; + + JsonResponse(std::shared_ptr state, JsonParserFn jsonParser, uint16_t statusCode, uint32_t contentLength, std::map headers) + : m_state(state) + , m_jsonParser(jsonParser) + , m_error(HTTPError::None) + , m_retryAfterSeconds(0) + , m_statusCode(statusCode) + , m_contentLength(contentLength) + , m_headers(std::move(headers)) + { + } + public: + JsonResponse(HTTPError error) + : m_state() + , m_jsonParser() + , m_error(error) + , m_retryAfterSeconds(0) + , m_statusCode(0) + , m_contentLength(0) + , m_headers() + { + } + JsonResponse(HTTPError error, uint32_t retryAfterSeconds) + : m_state() + , m_jsonParser() + , m_error(error) + , m_retryAfterSeconds(retryAfterSeconds) + , m_statusCode(0) + , m_contentLength(0) + , m_headers() + { + } + + inline bool Ok() const { return m_error == HTTPError::None && !m_state.expired(); } + inline HTTPError Error() const { return m_error; } + inline uint32_t RetryAfterSeconds() const { return m_retryAfterSeconds; } + inline uint16_t StatusCode() const { return m_statusCode; } + inline uint32_t ContentLength() const { return m_contentLength; } + + inline ReadResult ReadJson() + { + auto locked = m_state.lock(); + if (locked == nullptr) return HTTPError::ConnectionClosed; + + return locked->ReadJsonImpl(m_contentLength, m_jsonParser); + } + private: + std::weak_ptr m_state; + JsonParserFn m_jsonParser; + HTTPError m_error; + uint32_t m_retryAfterSeconds; + uint16_t m_statusCode; + uint32_t m_contentLength; + std::map m_headers; + }; +} // namespace OpenShock::HTTP diff --git a/include/http/RateLimiters.h b/include/http/RateLimiters.h new file mode 100644 index 00000000..cd37c726 --- /dev/null +++ b/include/http/RateLimiters.h @@ -0,0 +1,10 @@ +#pragma once + +#include "RateLimiter.h" + +#include +#include + +namespace OpenShock::HTTP::RateLimiters { + std::shared_ptr GetRateLimiter(std::string_view url); +} // namespace OpenShock::HTTP diff --git a/include/http/ReadResult.h b/include/http/ReadResult.h new file mode 100644 index 00000000..e4577e63 --- /dev/null +++ b/include/http/ReadResult.h @@ -0,0 +1,17 @@ +#pragma once + +#include "http/HTTPError.h" + +namespace OpenShock::HTTP { + template + struct [[nodiscard]] ReadResult { + HTTPError error{}; + T data{}; + + ReadResult(const T& d) + : error(HTTPError::None), data(d) {} + + ReadResult(const HTTPError& e) + : error(e), data{} {} + }; +} // namespace OpenShock::HTTP diff --git a/include/serialization/JsonAPI.h b/include/serialization/JsonAPI.h index 322e8c8b..ddf3a568 100644 --- a/include/serialization/JsonAPI.h +++ b/include/serialization/JsonAPI.h @@ -2,12 +2,12 @@ #include "ShockerModelType.h" -#include - #include #include #include +class cJSON; + namespace OpenShock::Serialization::JsonAPI { struct LcgInstanceDetailsResponse { std::string name; @@ -41,9 +41,9 @@ namespace OpenShock::Serialization::JsonAPI { std::string country; }; - bool ParseLcgInstanceDetailsJsonResponse(int code, const cJSON* root, LcgInstanceDetailsResponse& out); - bool ParseBackendVersionJsonResponse(int code, const cJSON* root, BackendVersionResponse& out); - bool ParseAccountLinkJsonResponse(int code, const cJSON* root, AccountLinkResponse& out); - bool ParseHubInfoJsonResponse(int code, const cJSON* root, HubInfoResponse& out); - bool ParseAssignLcgJsonResponse(int code, const cJSON* root, AssignLcgResponse& out); + bool ParseLcgInstanceDetailsJsonResponse(const cJSON* root, LcgInstanceDetailsResponse& out); + bool ParseBackendVersionJsonResponse(const cJSON* root, BackendVersionResponse& out); + bool ParseAccountLinkJsonResponse(const cJSON* root, AccountLinkResponse& out); + bool ParseHubInfoJsonResponse(const cJSON* root, HubInfoResponse& out); + bool ParseAssignLcgJsonResponse(const cJSON* root, AssignLcgResponse& out); } // namespace OpenShock::Serialization::JsonAPI diff --git a/include/util/DomainUtils.h b/include/util/DomainUtils.h new file mode 100644 index 00000000..99ad7248 --- /dev/null +++ b/include/util/DomainUtils.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace OpenShock::DomainUtils { + std::string_view GetDomainFromUrl(std::string_view url); +} // namespace OpenShock::DomainUtils diff --git a/include/util/PartitionUtils.h b/include/util/PartitionUtils.h index 22669020..0745894c 100644 --- a/include/util/PartitionUtils.h +++ b/include/util/PartitionUtils.h @@ -8,5 +8,5 @@ namespace OpenShock { bool TryGetPartitionHash(const esp_partition_t* partition, char (&hash)[65]); - bool FlashPartitionFromUrl(const esp_partition_t* partition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32], std::function progressCallback = nullptr); + bool FlashPartitionFromUrl(const esp_partition_t* partition, const char* remoteUrl, const uint8_t (&remoteHash)[32], std::function progressCallback = nullptr); } diff --git a/src/GatewayConnectionManager.cpp b/src/GatewayConnectionManager.cpp index 3204373c..d3da4636 100644 --- a/src/GatewayConnectionManager.cpp +++ b/src/GatewayConnectionManager.cpp @@ -106,32 +106,37 @@ AccountLinkResultCode GatewayConnectionManager::Link(std::string_view linkCode) } auto response = HTTP::JsonAPI::LinkAccount(linkCode); + if (!response.Ok()) { - if (response.code == 404) { - return AccountLinkResultCode::InvalidCode; + if (response.Error() == HTTP::HTTPError::RateLimited) { + return AccountLinkResultCode::InternalError; // Just return false, don't spam the console with errors + } + + OS_LOGE(TAG, "Error while linking account: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); + return AccountLinkResultCode::InternalError; } - if (response.result == HTTP::RequestResult::RateLimited) { - OS_LOGW(TAG, "Account Link request got ratelimited"); - return AccountLinkResultCode::RateLimited; + if (response.StatusCode() == 404) { + return AccountLinkResultCode::InvalidCode; } - if (response.result != HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Error while getting auth token: %s %d", response.ResultToString(), response.code); + if (response.StatusCode() != 200) { + OS_LOGE(TAG, "Unexpected response code: %d", response.StatusCode()); return AccountLinkResultCode::InternalError; } - if (response.code != 200) { - OS_LOGE(TAG, "Unexpected response code: %d", response.code); + auto content = response.ReadJson(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Error while reading response: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); return AccountLinkResultCode::InternalError; } - if (response.data.authToken.empty()) { + if (content.data.authToken.empty()) { OS_LOGE(TAG, "Received empty auth token"); return AccountLinkResultCode::InternalError; } - if (!Config::SetBackendAuthToken(std::move(response.data.authToken))) { + if (!Config::SetBackendAuthToken(std::move(content.data.authToken))) { OS_LOGE(TAG, "Failed to save auth token"); return AccountLinkResultCode::InternalError; } @@ -166,7 +171,7 @@ bool GatewayConnectionManager::SendMessageBIN(tcb::span data) return s_wsClient->sendMessageBIN(data); } -bool FetchHubInfo(std::string authToken) +bool FetchHubInfo(const char* authToken) { // TODO: this function is very slow, should be optimized! if ((s_flags & FLAG_HAS_IP) == 0) { @@ -177,31 +182,37 @@ bool FetchHubInfo(std::string authToken) return false; } - auto response = HTTP::JsonAPI::GetHubInfo(std::move(authToken)); + auto response = HTTP::JsonAPI::GetHubInfo(authToken); + if (!response.Ok()) { + if (response.Error() == HTTP::HTTPError::RateLimited) { + return false; // Just return false, don't spam the console with errors + } - if (response.code == 401) { - OS_LOGD(TAG, "Auth token is invalid, waiting 5 minutes before checking again"); - s_lastAuthFailure = OpenShock::micros(); + OS_LOGE(TAG, "Error while fetching hub info: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); return false; } - if (response.result == HTTP::RequestResult::RateLimited) { - return false; // Just return false, don't spam the console with errors + if (response.StatusCode() == 401) { + OS_LOGD(TAG, "Auth token is invalid, waiting 5 minutes before retrying"); + s_lastAuthFailure = OpenShock::micros(); + return false; } - if (response.result != HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Error while fetching hub info: %s %d", response.ResultToString(), response.code); + + if (response.StatusCode() != 200) { + OS_LOGE(TAG, "Unexpected response code: %d", response.StatusCode()); return false; } - if (response.code != 200) { - OS_LOGE(TAG, "Unexpected response code: %d", response.code); + auto content = response.ReadJson(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Error while reading response: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); return false; } - OS_LOGI(TAG, "Hub ID: %s", response.data.hubId.c_str()); - OS_LOGI(TAG, "Hub Name: %s", response.data.hubName.c_str()); + OS_LOGI(TAG, "Hub ID: %s", content.data.hubId.c_str()); + OS_LOGI(TAG, "Hub Name: %s", content.data.hubName.c_str()); OS_LOGI(TAG, "Shockers:"); - for (auto& shocker : response.data.shockers) { + for (auto& shocker : content.data.shockers) { OS_LOGI(TAG, " [%s] rf=%u model=%u", shocker.id.c_str(), shocker.rfId, shocker.model); } @@ -241,29 +252,35 @@ bool StartConnectingToLCG() return false; } - auto response = HTTP::JsonAPI::AssignLcg(std::move(authToken)); + auto response = HTTP::JsonAPI::AssignLcg(authToken.c_str()); + if (!response.Ok()) { + if (response.Error() == HTTP::HTTPError::RateLimited) { + return false; // Just return false, don't spam the console with errors + } + + OS_LOGE(TAG, "Error while fetching LCG endpoint: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); + return false; + } - if (response.code == 401) { + if (response.StatusCode() == 401) { OS_LOGD(TAG, "Auth token is invalid, waiting 5 minutes before retrying"); s_lastAuthFailure = OpenShock::micros(); return false; } - if (response.result == HTTP::RequestResult::RateLimited) { - return false; // Just return false, don't spam the console with errors - } - if (response.result != HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Error while fetching LCG endpoint: %s %d", response.ResultToString(), response.code); + if (response.StatusCode() != 200) { + OS_LOGE(TAG, "Unexpected response code: %d", response.StatusCode()); return false; } - if (response.code != 200) { - OS_LOGE(TAG, "Unexpected response code: %d", response.code); + auto content = response.ReadJson(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Error while reading response: %s %d", HTTP::HTTPErrorToString(response.Error()), response.StatusCode()); return false; } - OS_LOGD(TAG, "Connecting to LCG endpoint { host: '%s', port: %hu, path: '%s' } in country %s", response.data.host.c_str(), response.data.port, response.data.path.c_str(), response.data.country.c_str()); - s_wsClient->connect(response.data.host, response.data.port, response.data.path); + OS_LOGD(TAG, "Connecting to LCG endpoint { host: '%s', port: %hu, path: '%s' } in country %s", content.data.host.c_str(), content.data.port, content.data.path.c_str(), content.data.country.c_str()); + s_wsClient->connect(content.data.host, content.data.port, content.data.path); return true; } @@ -283,7 +300,7 @@ void GatewayConnectionManager::Update() } // Fetch hub info - if (!FetchHubInfo(std::move(authToken))) { + if (!FetchHubInfo(authToken.c_str())) { return; } diff --git a/src/OtaUpdateManager.cpp b/src/OtaUpdateManager.cpp index 77c5ed40..b366996b 100644 --- a/src/OtaUpdateManager.cpp +++ b/src/OtaUpdateManager.cpp @@ -1,3 +1,8 @@ +#include + +#include +#include + #include "OtaUpdateManager.h" const char* const TAG = "OtaUpdateManager"; @@ -8,7 +13,7 @@ const char* const TAG = "OtaUpdateManager"; #include "Core.h" #include "GatewayConnectionManager.h" #include "Hashing.h" -#include "http/HTTPRequestManager.h" +#include "http/HTTPClient.h" #include "Logging.h" #include "SemVer.h" #include "serialization/WSGateway.h" @@ -22,9 +27,6 @@ const char* const TAG = "OtaUpdateManager"; #include #include -#include -#include - #include #include @@ -159,7 +161,7 @@ static bool _sendFailureMessage(std::string_view message, bool fatal = false) return true; } -static bool _flashAppPartition(const esp_partition_t* partition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) +static bool _flashAppPartition(const esp_partition_t* partition, const char* remoteUrl, const uint8_t (&remoteHash)[32]) { OS_LOGD(TAG, "Flashing app partition"); @@ -195,7 +197,7 @@ static bool _flashAppPartition(const esp_partition_t* partition, std::string_vie return true; } -static bool _flashFilesystemPartition(const esp_partition_t* parition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32]) +static bool _flashFilesystemPartition(const esp_partition_t* parition, const char* remoteUrl, const uint8_t (&remoteHash)[32]) { if (!_sendProgressMessage(Serialization::Types::OtaUpdateProgressTask::PreparingForUpdate, 0.0f)) { return false; @@ -396,8 +398,8 @@ static void otaum_updatetask(void* arg) esp_task_wdt_init(15, true); // Flash app and filesystem partitions. - if (!_flashFilesystemPartition(filesystemPartition, release.filesystemBinaryUrl, release.filesystemBinaryHash)) continue; - if (!_flashAppPartition(appPartition, release.appBinaryUrl, release.appBinaryHash)) continue; + if (!_flashFilesystemPartition(filesystemPartition, release.filesystemBinaryUrl.c_str(), release.filesystemBinaryHash)) continue; + if (!_flashAppPartition(appPartition, release.appBinaryUrl.c_str(), release.appBinaryHash)) continue; // Set OTA boot type in config. if (!Config::SetOtaUpdateStep(OpenShock::OtaUpdateStep::Updated)) { @@ -422,25 +424,30 @@ static void otaum_updatetask(void* arg) esp_restart(); } -static bool _tryGetStringList(std::string_view url, std::vector& list) +static bool _tryGetStringList(const char* url, std::vector& list) { - auto response = OpenShock::HTTP::GetString( - url, - { - {"Accept", "text/plain"} - }, - std::array {200, 304} - ); - if (response.result != OpenShock::HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Failed to fetch list: %s [%u] %s", response.ResultToString(), response.code, response.data.c_str()); + HTTP::HTTPClient client(url); + auto response = client.Get(); + if (!response.Ok()) { + OS_LOGE(TAG, "Failed to fetch list"); return false; } - list.clear(); + uint16_t statusCode = response.StatusCode(); + if (statusCode != 200 && statusCode != 304) { + OS_LOGE(TAG, "Failed to fetch list"); + return false; + } + + auto content = response.ReadString(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Failed to fetch list: %s [%u] %s", HTTP::HTTPErrorToString(response.Error()), response.StatusCode(), content.data.c_str()); + return false; + } - std::string_view data = response.data; + list.clear(); - auto lines = OpenShock::StringSplitNewLines(data); + auto lines = OpenShock::StringSplitNewLines(content.data); list.reserve(lines.size()); for (auto line : lines) { @@ -525,16 +532,16 @@ bool OtaUpdateManager::Init() bool OtaUpdateManager::TryGetFirmwareVersion(OtaUpdateChannel channel, OpenShock::SemVer& version) { - std::string_view channelIndexUrl; + const char* channelIndexUrl; switch (channel) { case OtaUpdateChannel::Stable: - channelIndexUrl = OPENSHOCK_FW_CDN_STABLE_URL ""sv; + channelIndexUrl = OPENSHOCK_FW_CDN_STABLE_URL; break; case OtaUpdateChannel::Beta: - channelIndexUrl = OPENSHOCK_FW_CDN_BETA_URL ""sv; + channelIndexUrl = OPENSHOCK_FW_CDN_BETA_URL; break; case OtaUpdateChannel::Develop: - channelIndexUrl = OPENSHOCK_FW_CDN_DEVELOP_URL ""sv; + channelIndexUrl = OPENSHOCK_FW_CDN_DEVELOP_URL; break; default: OS_LOGE(TAG, "Unknown channel: %u", channel); @@ -543,20 +550,27 @@ bool OtaUpdateManager::TryGetFirmwareVersion(OtaUpdateChannel channel, OpenShock OS_LOGD(TAG, "Fetching firmware version from %s", channelIndexUrl); - auto response = OpenShock::HTTP::GetString( - channelIndexUrl, - { - {"Accept", "text/plain"} - }, - std::array {200, 304} - ); - if (response.result != OpenShock::HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Failed to fetch firmware version: %s [%u] %s", response.ResultToString(), response.code, response.data.c_str()); + HTTP::HTTPClient client(channelIndexUrl); + auto response = client.Get(); + if (!response.Ok()) { + OS_LOGE(TAG, "Failed to fetch firmware version"); + return false; + } + + uint16_t statusCode = response.StatusCode(); + if (statusCode != 200 && statusCode != 304) { + OS_LOGE(TAG, "Failed to fetch firmware version"); + return false; + } + + auto content = response.ReadString(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Failed to fetch firmware version: %s [%u] %s", HTTP::HTTPErrorToString(response.Error()), response.StatusCode(), content.data.c_str()); return false; } - if (!OpenShock::TryParseSemVer(response.data, version)) { - OS_LOGE(TAG, "Failed to parse firmware version: %.*s", response.data.size(), response.data.data()); + if (!OpenShock::TryParseSemVer(content.data, version)) { + OS_LOGE(TAG, "Failed to parse firmware version: %.*s", content.data.size(), content.data.data()); return false; } @@ -573,7 +587,7 @@ bool OtaUpdateManager::TryGetFirmwareBoards(const OpenShock::SemVer& version, st OS_LOGD(TAG, "Fetching firmware boards from %s", channelIndexUrl.c_str()); - if (!_tryGetStringList(channelIndexUrl, boards)) { + if (!_tryGetStringList(channelIndexUrl.c_str(), boards)) { OS_LOGE(TAG, "Failed to fetch firmware boards"); return false; } @@ -613,19 +627,26 @@ bool OtaUpdateManager::TryGetFirmwareRelease(const OpenShock::SemVer& version, F } // Fetch hashes. - auto sha256HashesResponse = OpenShock::HTTP::GetString( - sha256HashesUrl, - { - {"Accept", "text/plain"} - }, - std::array {200, 304} - ); - if (sha256HashesResponse.result != OpenShock::HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Failed to fetch hashes: %s [%u] %s", sha256HashesResponse.ResultToString(), sha256HashesResponse.code, sha256HashesResponse.data.c_str()); + HTTP::HTTPClient client(sha256HashesUrl.c_str()); + auto response = client.Get(); + if (!response.Ok()) { + OS_LOGE(TAG, "Failed to fetch hashes"); + return false; + } + + uint16_t statusCode = response.StatusCode(); + if (statusCode != 200 && statusCode != 304) { + OS_LOGE(TAG, "Failed to fetch hashes"); + return false; + } + + auto content = response.ReadString(); + if (content.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Failed to fetch hashes: %s [%u] %s", HTTP::HTTPErrorToString(response.Error()), response.StatusCode(), content.data.c_str()); return false; } - auto hashesLines = OpenShock::StringSplitNewLines(sha256HashesResponse.data); + auto hashesLines = OpenShock::StringSplitNewLines(content.data); // Parse hashes. bool foundAppHash = false, foundFilesystemHash = false; diff --git a/src/http/HTTPClientState.cpp b/src/http/HTTPClientState.cpp new file mode 100644 index 00000000..7e178684 --- /dev/null +++ b/src/http/HTTPClientState.cpp @@ -0,0 +1,238 @@ +#include + +#include "http/HTTPClientState.h" + +const char* const TAG = "HTTPClientState"; + +#include "Common.h" +#include "Convert.h" +#include "Logging.h" + +#include + +static const uint32_t HTTP_BUFFER_SIZE = 4096LLU; +static const uint32_t HTTP_DOWNLOAD_SIZE_LIMIT = 200 * 1024 * 1024; // 200 MB + +using namespace OpenShock; + +HTTP::HTTPClientState::HTTPClientState(const char* url, uint32_t timeoutMs) + : m_handle(nullptr) + , m_reading(false) + , m_retryAfterSeconds(0) + , m_headers() +{ + esp_http_client_config_t cfg; + memset(&cfg, 0, sizeof(cfg)); + + cfg.url = url; + cfg.user_agent = OpenShock::Constants::FW_USERAGENT; + cfg.timeout_ms = static_cast(std::min(timeoutMs, INT32_MAX)); + cfg.disable_auto_redirect = true; + cfg.event_handler = HTTPClientState::EventHandler; + cfg.transport_type = HTTP_TRANSPORT_OVER_SSL; + cfg.user_data = reinterpret_cast(this); + cfg.is_async = false; + cfg.use_global_ca_store = true; + + m_handle = esp_http_client_init(&cfg); +} + +HTTP::HTTPClientState::~HTTPClientState() +{ + if (m_handle != nullptr) { + esp_http_client_cleanup(m_handle); + m_handle = nullptr; + } +} + +esp_err_t HTTP::HTTPClientState::SetUrl(const char* url) +{ + if (m_handle == nullptr) { + return ESP_FAIL; + } + + return esp_http_client_set_url(m_handle, url); +} + +esp_err_t HTTP::HTTPClientState::SetHeader(const char* key, const char* value) +{ + if (m_handle == nullptr) { + return ESP_FAIL; + } + + return esp_http_client_set_header(m_handle, key, value); +} + +HTTP::HTTPClientState::StartRequestResult HTTP::HTTPClientState::StartRequest(esp_http_client_method_t method, int writeLen) +{ + if (m_handle == nullptr) { + return { .error = HTTPError::ConnectionClosed }; + } + + esp_err_t err; + + if (m_reading) { + return { .error = HTTPError::ClientBusy }; + } + + m_retryAfterSeconds = 0; + m_headers.clear(); + + err = esp_http_client_set_method(m_handle, method); + if (err != ESP_OK) return { .error = HTTPError::InvalidHttpMethod }; + + err = esp_http_client_open(m_handle, writeLen); + if (err != ESP_OK) return { .error = HTTPError::NetworkError }; + + int contentLength = esp_http_client_fetch_headers(m_handle); + if (contentLength < 0) return { .error = HTTPError::NetworkError }; + + if (m_retryAfterSeconds > 0) { + uint32_t retryAfterSeconds = m_retryAfterSeconds; + m_retryAfterSeconds = 0; + return { .error = HTTPError::RateLimited, .retryAfterSeconds = retryAfterSeconds }; + } + + bool isChunked = false; + if (contentLength == 0) { + isChunked = esp_http_client_is_chunked_response(m_handle); + } + + int statusCode = esp_http_client_get_status_code(m_handle); + if (statusCode < 0 || statusCode > 599) { + OS_LOGE(TAG, "Returned statusCode is invalid (%i)", statusCode); + return { .error = HTTPError::NetworkError }; + } + + m_reading = true; + + return StartRequestResult { + .statusCode = static_cast(statusCode), + .isChunked = isChunked, + .contentLength = static_cast(contentLength), + .headers = std::move(m_headers) + }; +} + +HTTP::ReadResult HTTP::HTTPClientState::ReadStreamImpl(DownloadCallback cb) +{ + if (m_handle == nullptr || !m_reading) { + m_reading = false; + return HTTPError::ConnectionClosed; + } + + uint32_t totalWritten = 0; + uint8_t buffer[HTTP_BUFFER_SIZE]; + + while (true) { + if (totalWritten >= HTTP_DOWNLOAD_SIZE_LIMIT) { + m_reading = false; + return HTTPError::SizeLimitExceeded; + } + + uint32_t remaining = HTTP_DOWNLOAD_SIZE_LIMIT - totalWritten; + int toRead = static_cast(std::min(HTTP_BUFFER_SIZE, remaining)); + + int n = esp_http_client_read( + m_handle, + reinterpret_cast(buffer), + toRead + ); + + if (n < 0) { + m_reading = false; + return HTTPError::NetworkError; + } + + if (n == 0) { + // EOF + break; + } + + uint32_t chunkLen = static_cast(n); + if (!cb(totalWritten, buffer, chunkLen)) { + m_reading = false; + return HTTPError::Aborted; + } + + totalWritten += chunkLen; + } + + m_reading = false; + return totalWritten; +} + +HTTP::ReadResult HTTP::HTTPClientState::ReadStringImpl(uint32_t reserve) +{ + std::string result; + if (reserve > 0) { + result.reserve(reserve); + } + + auto writer = [&result](std::size_t offset, const uint8_t* data, std::size_t len) { + result.append(reinterpret_cast(data), len); + return true; + }; + + auto response = ReadStreamImpl(writer); + if (response.error != HTTPError::None) { + return response.error; + } + + return result; +} + +esp_err_t HTTP::HTTPClientState::EventHandler(esp_http_client_event_t* evt) +{ + HTTPClientState* client = reinterpret_cast(evt->user_data); + + switch (evt->event_id) + { + case HTTP_EVENT_ERROR: + OS_LOGE(TAG, "Got error event"); + break; + case HTTP_EVENT_ON_CONNECTED: + OS_LOGI(TAG, "Got connected event"); + break; + case HTTP_EVENT_HEADERS_SENT: + OS_LOGI(TAG, "Got headers_sent event"); + break; + case HTTP_EVENT_ON_HEADER: + return client->EventHeaderHandler(evt->header_key, evt->header_value); + case HTTP_EVENT_ON_DATA: + OS_LOGI(TAG, "Got on_data event"); + break; + case HTTP_EVENT_ON_FINISH: + OS_LOGI(TAG, "Got on_finish event"); + break; + case HTTP_EVENT_DISCONNECTED: + OS_LOGI(TAG, "Got disconnected event"); + break; + default: + OS_LOGE(TAG, "Got unknown event"); + break; + } + + return ESP_OK; +} + +esp_err_t HTTP::HTTPClientState::EventHeaderHandler(std::string key, std::string value) +{ + OS_LOGI(TAG, "Got header_received event: %.*s - %.*s", key.length(), key.c_str(), value.length(), value.c_str()); + + std::transform(key.begin(), key.end(), key.begin(), [](unsigned char c) { return std::tolower(c); }); + + if (key == "retry-after") { + uint32_t seconds = 0; + if (!Convert::ToUint32(value, seconds) || seconds <= 0) { + seconds = 15; + } + + OS_LOGI(TAG, "Retry-After: %d seconds, applying delay to rate limiter", seconds); + m_retryAfterSeconds = seconds; + } + + m_headers[key] = std::move(value); + + return ESP_OK; +} diff --git a/src/http/HTTPRequestManager.cpp b/src/http/HTTPRequestManager.cpp deleted file mode 100644 index 63dfe9ba..00000000 --- a/src/http/HTTPRequestManager.cpp +++ /dev/null @@ -1,486 +0,0 @@ -#include "http/HTTPRequestManager.h" - -const char* const TAG = "HTTPRequestManager"; - -#include "Common.h" -#include "Core.h" -#include "Logging.h" -#include "RateLimiter.h" -#include "SimpleMutex.h" -#include "util/HexUtils.h" -#include "util/StringUtils.h" - -#include - -#include -#include -#include -#include -#include - -using namespace std::string_view_literals; - -const std::size_t HTTP_BUFFER_SIZE = 4096LLU; -const int HTTP_DOWNLOAD_SIZE_LIMIT = 200 * 1024 * 1024; // 200 MB - -static OpenShock::SimpleMutex s_rateLimitsMutex = {}; -static std::unordered_map> s_rateLimits = {}; - -using namespace OpenShock; - -std::string_view _getDomain(std::string_view url) -{ - if (url.empty()) { - return {}; - } - - // Remove the protocol eg. "https://api.example.com:443/path" -> "api.example.com:443/path" - auto seperator = url.find("://"); - if (seperator != std::string_view::npos) { - url.substr(seperator + 3); - } - - // Remove the path eg. "api.example.com:443/path" -> "api.example.com:443" - seperator = url.find('/'); - if (seperator != std::string_view::npos) { - url = url.substr(0, seperator); - } - - // Remove the port eg. "api.example.com:443" -> "api.example.com" - seperator = url.rfind(':'); - if (seperator != std::string_view::npos) { - url = url.substr(0, seperator); - } - - // Remove all subdomains eg. "api.example.com" -> "example.com" - seperator = url.rfind('.'); - if (seperator == std::string_view::npos) { - return url; // E.g. "localhost" - } - seperator = url.rfind('.', seperator - 1); - if (seperator != std::string_view::npos) { - url = url.substr(seperator + 1); - } - - return url; -} - -std::shared_ptr _rateLimiterFactory(std::string_view domain) -{ - auto rateLimit = std::make_shared(); - - // Add default limits - rateLimit->addLimit(1000, 5); // 5 per second - rateLimit->addLimit(10 * 1000, 10); // 10 per 10 seconds - - // per-domain limits - if (domain == OPENSHOCK_API_DOMAIN) { - rateLimit->addLimit(60 * 1000, 12); // 12 per minute - rateLimit->addLimit(60 * 60 * 1000, 120); // 120 per hour - } - - return rateLimit; -} - -std::shared_ptr _getRateLimiter(std::string_view url) -{ - auto domain = std::string(_getDomain(url)); - if (domain.empty()) { - return nullptr; - } - - s_rateLimitsMutex.lock(portMAX_DELAY); - - auto it = s_rateLimits.find(domain); - if (it == s_rateLimits.end()) { - s_rateLimits.emplace(domain, _rateLimiterFactory(domain)); - it = s_rateLimits.find(domain); - } - - s_rateLimitsMutex.unlock(); - - return it->second; -} - -void _setupClient(HTTPClient& client) -{ - client.setUserAgent(OpenShock::Constants::FW_USERAGENT); -} - -struct StreamReaderResult { - HTTP::RequestResult result; - std::size_t nWritten; -}; - -constexpr bool _isCRLF(const uint8_t* buffer) -{ - return buffer[0] == '\r' && buffer[1] == '\n'; -} -constexpr bool _tryFindCRLF(std::size_t& pos, const uint8_t* buffer, std::size_t len) -{ - const uint8_t* cur = buffer; - const uint8_t* end = buffer + len - 1; - - while (cur < end) { - if (_isCRLF(cur)) { - pos = static_cast(cur - buffer); - return true; - } - - ++cur; - } - - return false; -} - -enum ParserState : uint8_t { - Ok, - NeedMoreData, - Invalid, -}; - -ParserState _parseChunkHeader(const uint8_t* buffer, std::size_t bufferLen, std::size_t& headerLen, std::size_t& payloadLen) -{ - if (bufferLen < 5) { // Bare minimum: "0\r\n\r\n" - return ParserState::NeedMoreData; - } - - // Find the first CRLF - if (!_tryFindCRLF(headerLen, buffer, bufferLen)) { - return ParserState::NeedMoreData; - } - - // Header must have at least one character - if (headerLen == 0) { - OS_LOGW(TAG, "Invalid chunk header length"); - return ParserState::Invalid; - } - - // Check for end of size field (possibly followed by extensions which is separated by a semicolon) - std::size_t sizeFieldEnd = headerLen; - for (std::size_t i = 0; i < headerLen; ++i) { - if (buffer[i] == ';') { - sizeFieldEnd = i; - break; - } - } - - // Bounds check - if (sizeFieldEnd == 0 || sizeFieldEnd > 16) { - OS_LOGW(TAG, "Invalid chunk size field length"); - return ParserState::Invalid; - } - - std::string_view sizeField(reinterpret_cast(buffer), sizeFieldEnd); - - // Parse the chunk size - if (!HexUtils::TryParseHexToInt(sizeField.data(), sizeField.length(), payloadLen)) { - OS_LOGW(TAG, "Failed to parse chunk size"); - return ParserState::Invalid; - } - - if (payloadLen > HTTP_DOWNLOAD_SIZE_LIMIT) { - OS_LOGW(TAG, "Chunk size too large"); - return ParserState::Invalid; - } - - // Set the header length to the end of the CRLF - headerLen += 2; - - return ParserState::Ok; -} - -ParserState _parseChunk(const uint8_t* buffer, std::size_t bufferLen, std::size_t& payloadPos, std::size_t& payloadLen) -{ - if (payloadPos == 0) { - ParserState state = _parseChunkHeader(buffer, bufferLen, payloadPos, payloadLen); - if (state != ParserState::Ok) { - return state; - } - } - - std::size_t totalLen = payloadPos + payloadLen + 2; // +2 for CRLF - if (bufferLen < totalLen) { - return ParserState::NeedMoreData; - } - - // Check for CRLF - if (!_isCRLF(buffer + totalLen - 2)) { - OS_LOGW(TAG, "Invalid chunk payload CRLF"); - return ParserState::Invalid; - } - - return ParserState::Ok; -} - -void _alignChunk(uint8_t* buffer, std::size_t& bufferCursor, std::size_t payloadPos, std::size_t payloadLen) -{ - std::size_t totalLen = payloadPos + payloadLen + 2; // +2 for CRLF - std::size_t remaining = bufferCursor - totalLen; - if (remaining > 0) { - memmove(buffer, buffer + totalLen, remaining); - bufferCursor = remaining; - } else { - bufferCursor = 0; - } -} - -StreamReaderResult _readStreamDataChunked(HTTPClient& client, WiFiClient* stream, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) -{ - std::size_t totalWritten = 0; - HTTP::RequestResult result = HTTP::RequestResult::Success; - - uint8_t* buffer = static_cast(malloc(HTTP_BUFFER_SIZE)); - if (buffer == nullptr) { - OS_LOGE(TAG, "Out of memory"); - return {HTTP::RequestResult::RequestFailed, 0}; - } - - ParserState state = ParserState::NeedMoreData; - std::size_t bufferCursor = 0, payloadPos = 0, payloadSize = 0; - - while (client.connected() && state != ParserState::Invalid) { - if (begin + timeoutMs < OpenShock::millis()) { - OS_LOGW(TAG, "Request timed out"); - result = HTTP::RequestResult::TimedOut; - break; - } - - std::size_t bytesAvailable = stream->available(); - if (bytesAvailable == 0) { - vTaskDelay(pdMS_TO_TICKS(5)); - continue; - } - - std::size_t bytesRead = stream->readBytes(buffer + bufferCursor, HTTP_BUFFER_SIZE - bufferCursor); - if (bytesRead == 0) { - OS_LOGW(TAG, "No bytes read"); - result = HTTP::RequestResult::RequestFailed; - break; - } - - bufferCursor += bytesRead; - - while (bufferCursor > 0) { - state = _parseChunk(buffer, bufferCursor, payloadPos, payloadSize); - if (state == ParserState::Invalid) { - OS_LOGE(TAG, "Failed to parse chunk"); - result = HTTP::RequestResult::RequestFailed; - state = ParserState::Invalid; // Mark to exit both loops - break; - } - OS_LOGD(TAG, "Chunk parsed: %zu %zu", payloadPos, payloadSize); - - if (state == ParserState::NeedMoreData) { - if (bufferCursor == HTTP_BUFFER_SIZE) { - OS_LOGE(TAG, "Chunk too large"); - result = HTTP::RequestResult::RequestFailed; - state = ParserState::Invalid; // Mark to exit both loops - } - break; // If chunk size good, this only exits one loop - } - - // Check for zero chunk size (end of transfer) - if (payloadSize == 0) { - state = ParserState::Invalid; // Mark to exit both loops - break; - } - - if (!downloadCallback(totalWritten, buffer + payloadPos, payloadSize)) { - result = HTTP::RequestResult::Cancelled; - state = ParserState::Invalid; // Mark to exit both loops - break; - } - - totalWritten += payloadSize; - _alignChunk(buffer, bufferCursor, payloadPos, payloadSize); - payloadSize = 0; - payloadPos = 0; - } - - if (state == ParserState::NeedMoreData) { - vTaskDelay(pdMS_TO_TICKS(5)); - } - } - - free(buffer); - - return {result, totalWritten}; -} - -StreamReaderResult _readStreamData(HTTPClient& client, WiFiClient* stream, std::size_t contentLength, HTTP::DownloadCallback downloadCallback, int64_t begin, uint32_t timeoutMs) -{ - std::size_t nWritten = 0; - HTTP::RequestResult result = HTTP::RequestResult::Success; - - uint8_t* buffer = static_cast(malloc(HTTP_BUFFER_SIZE)); - - while (client.connected() && nWritten < contentLength) { - if (begin + timeoutMs < OpenShock::millis()) { - OS_LOGW(TAG, "Request timed out"); - result = HTTP::RequestResult::TimedOut; - break; - } - - std::size_t bytesAvailable = stream->available(); - if (bytesAvailable == 0) { - vTaskDelay(pdMS_TO_TICKS(5)); - continue; - } - - std::size_t bytesToRead = std::min(bytesAvailable, HTTP_BUFFER_SIZE); - - std::size_t bytesRead = stream->readBytes(buffer, bytesToRead); - if (bytesRead == 0) { - OS_LOGW(TAG, "No bytes read"); - result = HTTP::RequestResult::RequestFailed; - break; - } - - if (!downloadCallback(nWritten, buffer, bytesRead)) { - OS_LOGW(TAG, "Request cancelled by callback"); - result = HTTP::RequestResult::Cancelled; - break; - } - - nWritten += bytesRead; - - vTaskDelay(pdMS_TO_TICKS(10)); - } - - free(buffer); - - return {result, nWritten}; -} - -HTTP::Response _doGetStream( - HTTPClient& client, - std::string_view url, - const std::map& headers, - tcb::span acceptedCodes, - std::shared_ptr rateLimiter, - HTTP::GotContentLengthCallback contentLengthCallback, - HTTP::DownloadCallback downloadCallback, - uint32_t timeoutMs -) -{ - int64_t begin = OpenShock::millis(); - if (!client.begin(OpenShock::StringToArduinoString(url))) { - OS_LOGE(TAG, "Failed to begin HTTP request"); - return {HTTP::RequestResult::RequestFailed, 0, 0}; - } - - for (auto& header : headers) { - client.addHeader(header.first, header.second); - } - - int responseCode = client.GET(); - - if (responseCode == HTTP_CODE_REQUEST_TIMEOUT || begin + timeoutMs < OpenShock::millis()) { - OS_LOGW(TAG, "Request timed out"); - return {HTTP::RequestResult::TimedOut, responseCode, 0}; - } - - if (responseCode == HTTP_CODE_TOO_MANY_REQUESTS) { - // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After - - // Get "Retry-After" header - String retryAfterStr = client.header("Retry-After"); - - // Try to parse it as an integer (delay-seconds) - long retryAfter = 0; - if (retryAfterStr.length() > 0 && std::all_of(retryAfterStr.begin(), retryAfterStr.end(), isdigit)) { - retryAfter = retryAfterStr.toInt(); - } - - // If header missing/unparseable, default to 15 seconds - if (retryAfter <= 0) { - retryAfter = 15; - } - - // Apply the block-for time - rateLimiter->blockFor(retryAfter * 1000); - - return {HTTP::RequestResult::RateLimited, responseCode, 0}; - } - - if (responseCode == 418) { - OS_LOGW(TAG, "The server refused to brew coffee because it is, permanently, a teapot."); - } - - if (std::find(acceptedCodes.begin(), acceptedCodes.end(), responseCode) == acceptedCodes.end()) { - OS_LOGD(TAG, "Received unexpected response code %d", responseCode); - return {HTTP::RequestResult::CodeRejected, responseCode, 0}; - } - - int contentLength = client.getSize(); - if (contentLength == 0) { - return {HTTP::RequestResult::Success, responseCode, 0}; - } - - if (contentLength > 0) { - if (contentLength > HTTP_DOWNLOAD_SIZE_LIMIT) { - OS_LOGE(TAG, "Content-Length too large"); - return {HTTP::RequestResult::RequestFailed, responseCode, 0}; - } - - if (!contentLengthCallback(contentLength)) { - OS_LOGW(TAG, "Request cancelled by callback"); - return {HTTP::RequestResult::Cancelled, responseCode, 0}; - } - } - - WiFiClient* stream = client.getStreamPtr(); - if (stream == nullptr) { - OS_LOGE(TAG, "Failed to get stream"); - return {HTTP::RequestResult::RequestFailed, 0, 0}; - } - - StreamReaderResult result; - if (contentLength > 0) { - result = _readStreamData(client, stream, contentLength, downloadCallback, begin, timeoutMs); - } else { - result = _readStreamDataChunked(client, stream, downloadCallback, begin, timeoutMs); - } - - return {result.result, responseCode, result.nWritten}; -} - -HTTP::Response - HTTP::Download(std::string_view url, const std::map& headers, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, tcb::span acceptedCodes, uint32_t timeoutMs) -{ - std::shared_ptr rateLimiter = _getRateLimiter(url); - if (rateLimiter == nullptr) { - return {RequestResult::InvalidURL, 0, 0}; - } - - if (!rateLimiter->tryRequest()) { - return {RequestResult::RateLimited, 0, 0}; - } - - HTTPClient client; - _setupClient(client); - - return _doGetStream(client, url, headers, acceptedCodes, rateLimiter, contentLengthCallback, downloadCallback, timeoutMs); -} - -HTTP::Response HTTP::GetString(std::string_view url, const std::map& headers, tcb::span acceptedCodes, uint32_t timeoutMs) -{ - std::string result; - - auto allocator = [&result](std::size_t contentLength) { - result.reserve(contentLength); - return true; - }; - auto writer = [&result](std::size_t offset, const uint8_t* data, std::size_t len) { - result.append(reinterpret_cast(data), len); - return true; - }; - - auto response = Download(url, headers, allocator, writer, acceptedCodes, timeoutMs); - if (response.result != RequestResult::Success) { - return {response.result, response.code, {}}; - } - - return {response.result, response.code, result}; -} diff --git a/src/http/JsonAPI.cpp b/src/http/JsonAPI.cpp index a83a4c09..0b563c78 100644 --- a/src/http/JsonAPI.cpp +++ b/src/http/JsonAPI.cpp @@ -2,68 +2,60 @@ #include "Common.h" #include "config/Config.h" +#include "http/HTTPClient.h" #include "util/StringUtils.h" using namespace OpenShock; -HTTP::Response HTTP::JsonAPI::LinkAccount(std::string_view accountLinkCode) +HTTP::JsonResponse HTTP::JsonAPI::LinkAccount(std::string_view accountLinkCode) { std::string domain; if (!Config::GetBackendDomain(domain)) { - return {HTTP::RequestResult::InternalError, 0, {}}; + return HTTPError::InternalError; } char uri[OPENSHOCK_URI_BUFFER_SIZE]; sprintf(uri, "https://%s/1/device/pair/%.*s", domain.c_str(), accountLinkCode.length(), accountLinkCode.data()); - return HTTP::GetJSON( - uri, - { - {"Accept", "application/json"} - }, - Serialization::JsonAPI::ParseAccountLinkJsonResponse, - std::array {200} - ); + HTTP::HTTPClient client(uri); + + client.SetHeader("Accept", "application/json"); + + return client.GetJson(Serialization::JsonAPI::ParseAccountLinkJsonResponse); } -HTTP::Response HTTP::JsonAPI::GetHubInfo(std::string_view hubToken) +HTTP::JsonResponse HTTP::JsonAPI::GetHubInfo(const char* hubToken) { std::string domain; if (!Config::GetBackendDomain(domain)) { - return {HTTP::RequestResult::InternalError, 0, {}}; + return HTTPError::InternalError; } char uri[OPENSHOCK_URI_BUFFER_SIZE]; sprintf(uri, "https://%s/1/device/self", domain.c_str()); - return HTTP::GetJSON( - uri, - { - { "Accept", "application/json"}, - {"DeviceToken", OpenShock::StringToArduinoString(hubToken)} - }, - Serialization::JsonAPI::ParseHubInfoJsonResponse, - std::array {200} - ); + HTTP::HTTPClient client(uri); + + client.SetHeader("Accept", "application/json"); + client.SetHeader("DeviceToken", hubToken); + + return client.GetJson(Serialization::JsonAPI::ParseHubInfoJsonResponse); } -HTTP::Response HTTP::JsonAPI::AssignLcg(std::string_view hubToken) +HTTP::JsonResponse HTTP::JsonAPI::AssignLcg(const char* hubToken) { std::string domain; if (!Config::GetBackendDomain(domain)) { - return {HTTP::RequestResult::InternalError, 0, {}}; + return HTTPError::InternalError; } char uri[OPENSHOCK_URI_BUFFER_SIZE]; sprintf(uri, "https://%s/2/device/assignLCG?version=2", domain.c_str()); - return HTTP::GetJSON( - uri, - { - { "Accept", "application/json"}, - {"DeviceToken", OpenShock::StringToArduinoString(hubToken)} - }, - Serialization::JsonAPI::ParseAssignLcgJsonResponse, - std::array {200} - ); + HTTP::HTTPClient client(uri); + + client.SetHeader("Accept", "application/json"); + client.SetHeader("DeviceToken", hubToken); + + return client.GetJson(Serialization::JsonAPI::ParseAssignLcgJsonResponse); } diff --git a/src/http/RateLimiters.cpp b/src/http/RateLimiters.cpp new file mode 100644 index 00000000..5a13b4f5 --- /dev/null +++ b/src/http/RateLimiters.cpp @@ -0,0 +1,49 @@ +#include + +#include "http/RateLimiters.h" + +#include "SimpleMutex.h" +#include "util/DomainUtils.h" + +#include +#include + +static OpenShock::SimpleMutex s_rateLimitsMutex = {}; +static std::unordered_map> s_rateLimits = {}; + +using namespace OpenShock; + +std::shared_ptr _rateLimiterFactory(std::string_view domain) +{ + auto rateLimit = std::make_shared(); + + // Add default limits + rateLimit->addLimit(1000, 5); // 5 per second + rateLimit->addLimit(10 * 1000, 10); // 10 per 10 seconds + + // per-domain limits + if (domain == OPENSHOCK_API_DOMAIN) { + rateLimit->addLimit(60 * 1000, 12); // 12 per minute + rateLimit->addLimit(60 * 60 * 1000, 120); // 120 per hour + } + + return rateLimit; +} + +std::shared_ptr HTTP::RateLimiters::GetRateLimiter(std::string_view url) +{ + auto domain = std::string(DomainUtils::GetDomainFromUrl(url)); + if (domain.empty()) { + return nullptr; + } + + OpenShock::ScopedLock lock__(&s_rateLimitsMutex); + + auto it = s_rateLimits.find(domain); + if (it == s_rateLimits.end()) { + s_rateLimits.emplace(domain, _rateLimiterFactory(domain)); + it = s_rateLimits.find(domain); + } + + return it->second; +} diff --git a/src/main.cpp b/src/main.cpp index 153f4988..49c56aa7 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -19,6 +19,8 @@ const char* const TAG = "main"; #include +#include + #include // Internal setup function, returns true if setup succeeded, false otherwise. @@ -89,11 +91,17 @@ void appSetup() } } +extern const uint8_t* global_ca_crt_bundle_start asm("_binary_certificates_x509_crt_bundle_start"); +extern const uint8_t* global_ca_crt_bundle_end asm("_binary_certificates_x509_crt_bundle_end"); + // Arduino setup function void setup() { ::Serial.begin(115'200); + esp_tls_init_global_ca_store(); + esp_tls_set_global_ca_store(global_ca_crt_bundle_start, static_cast(global_ca_crt_bundle_end - global_ca_crt_bundle_start)); + OpenShock::Config::Init(); if (!OpenShock::Events::Init()) { diff --git a/src/serial/SerialInputHandler.cpp b/src/serial/SerialInputHandler.cpp index 5bd0acca..9748fc16 100644 --- a/src/serial/SerialInputHandler.cpp +++ b/src/serial/SerialInputHandler.cpp @@ -10,7 +10,6 @@ const char* const TAG = "SerialInputHandler"; #include "Core.h" #include "estop/EStopManager.h" #include "FormatHelpers.h" -#include "http/HTTPRequestManager.h" #include "Logging.h" #include "serial/command_handlers/CommandEntry.h" #include "serial/command_handlers/common.h" diff --git a/src/serial/command_handlers/authtoken.cpp b/src/serial/command_handlers/authtoken.cpp index 18c9ade6..ff8b39a8 100644 --- a/src/serial/command_handlers/authtoken.cpp +++ b/src/serial/command_handlers/authtoken.cpp @@ -1,6 +1,7 @@ #include "serial/command_handlers/common.h" #include "config/Config.h" +#include "http/HTTPClient.h" #include "http/JsonAPI.h" #include @@ -18,15 +19,17 @@ void _handleAuthtokenCommand(std::string_view arg, bool isAutomated) { return; } - auto apiResponse = OpenShock::HTTP::JsonAPI::GetHubInfo(arg); - if (apiResponse.code == 401) { + std::string token = std::string(arg); + + auto apiResponse = OpenShock::HTTP::JsonAPI::GetHubInfo(token.c_str()); + + if (apiResponse.StatusCode() == 401) { SERPR_ERROR("Invalid auth token, refusing to save it!"); return; } // If we have some other kind of request fault just set it anyway, we probably arent connected to a network - - bool result = OpenShock::Config::SetBackendAuthToken(std::string(arg)); + bool result = OpenShock::Config::SetBackendAuthToken(std::move(token)); if (result) { SERPR_SUCCESS("Saved config"); diff --git a/src/serial/command_handlers/domain.cpp b/src/serial/command_handlers/domain.cpp index 91185593..7671c9ef 100644 --- a/src/serial/command_handlers/domain.cpp +++ b/src/serial/command_handlers/domain.cpp @@ -1,8 +1,8 @@ #include "serial/command_handlers/common.h" #include "config/Config.h" -#include "http/HTTPRequestManager.h" -#include "serialization/JsonAPI.h" +#include "http/HTTPClient.h" +#include "http/JsonAPI.h" #include @@ -33,21 +33,20 @@ void _handleDomainCommand(std::string_view arg, bool isAutomated) { char uri[OPENSHOCK_URI_BUFFER_SIZE]; sprintf(uri, "https://%.*s/1", arg.length(), arg.data()); - auto resp = OpenShock::HTTP::GetJSON( - uri, - { - {"Accept", "application/json"} - }, - OpenShock::Serialization::JsonAPI::ParseBackendVersionJsonResponse, - std::array {200} - ); - - if (resp.result != OpenShock::HTTP::RequestResult::Success) { - SERPR_ERROR("Tried to connect to \"%.*s\", but failed with status [%d] (%s), refusing to save domain to config", arg.length(), arg.data(), resp.code, resp.ResultToString()); + OpenShock::HTTP::HTTPClient client(uri); + auto response = client.GetJson(OpenShock::Serialization::JsonAPI::ParseBackendVersionJsonResponse); + if (!response.Ok() || response.StatusCode() != 200) { + SERPR_ERROR("Tried to connect to \"%.*s\", but failed with status [%d] (%s), refusing to save domain to config", arg.length(), arg.data(), response.StatusCode(), OpenShock::HTTP::HTTPErrorToString(response.Error())); return; } - OS_LOGI(TAG, "Successfully connected to \"%.*s\", version: %s, commit: %s, current time: %s", arg.length(), arg.data(), resp.data.version.c_str(), resp.data.commit.c_str(), resp.data.currentTime.c_str()); + auto content = response.ReadJson(); + if (content.error != OpenShock::HTTP::HTTPError::None) { + SERPR_ERROR("Tried to read response from backend, but failed (%s), refusing to save domain to config", OpenShock::HTTP::HTTPErrorToString(response.Error())); + return; + } + + OS_LOGI(TAG, "Successfully connected to \"%.*s\", version: %s, commit: %s, current time: %s", arg.length(), arg.data(), content.data.version.c_str(), content.data.commit.c_str(), content.data.currentTime.c_str()); bool result = OpenShock::Config::SetBackendDomain(std::string(arg)); diff --git a/src/serialization/JsonAPI.cpp b/src/serialization/JsonAPI.cpp index 1a412ae2..c62c0d36 100644 --- a/src/serialization/JsonAPI.cpp +++ b/src/serialization/JsonAPI.cpp @@ -2,16 +2,16 @@ const char* const TAG = "JsonAPI"; +#include + #include "Logging.h" #define ESP_LOGJSONE(err, root) OS_LOGE(TAG, "Invalid JSON response (" err "): %s", cJSON_PrintUnformatted(root)) using namespace OpenShock::Serialization; -bool JsonAPI::ParseLcgInstanceDetailsJsonResponse(int code, const cJSON* root, JsonAPI::LcgInstanceDetailsResponse& out) +bool JsonAPI::ParseLcgInstanceDetailsJsonResponse(const cJSON* root, JsonAPI::LcgInstanceDetailsResponse& out) { - (void)code; - if (cJSON_IsObject(root) == 0) { ESP_LOGJSONE("not an object", root); return false; @@ -57,10 +57,8 @@ bool JsonAPI::ParseLcgInstanceDetailsJsonResponse(int code, const cJSON* root, J return true; } -bool JsonAPI::ParseBackendVersionJsonResponse(int code, const cJSON* root, JsonAPI::BackendVersionResponse& out) +bool JsonAPI::ParseBackendVersionJsonResponse(const cJSON* root, JsonAPI::BackendVersionResponse& out) { - (void)code; - if (cJSON_IsObject(root) == 0) { ESP_LOGJSONE("not an object", root); return false; @@ -99,10 +97,8 @@ bool JsonAPI::ParseBackendVersionJsonResponse(int code, const cJSON* root, JsonA return true; } -bool JsonAPI::ParseAccountLinkJsonResponse(int code, const cJSON* root, JsonAPI::AccountLinkResponse& out) +bool JsonAPI::ParseAccountLinkJsonResponse(const cJSON* root, JsonAPI::AccountLinkResponse& out) { - (void)code; - if (cJSON_IsObject(root) == 0) { ESP_LOGJSONE("not an object", root); return false; @@ -120,10 +116,8 @@ bool JsonAPI::ParseAccountLinkJsonResponse(int code, const cJSON* root, JsonAPI: return true; } -bool JsonAPI::ParseHubInfoJsonResponse(int code, const cJSON* root, JsonAPI::HubInfoResponse& out) +bool JsonAPI::ParseHubInfoJsonResponse(const cJSON* root, JsonAPI::HubInfoResponse& out) { - (void)code; - if (cJSON_IsObject(root) == 0) { ESP_LOGJSONE("not an object", root); return false; @@ -213,10 +207,8 @@ bool JsonAPI::ParseHubInfoJsonResponse(int code, const cJSON* root, JsonAPI::Hub return true; } -bool JsonAPI::ParseAssignLcgJsonResponse(int code, const cJSON* root, JsonAPI::AssignLcgResponse& out) +bool JsonAPI::ParseAssignLcgJsonResponse(const cJSON* root, JsonAPI::AssignLcgResponse& out) { - (void)code; - if (cJSON_IsObject(root) == 0) { ESP_LOGJSONE("not an object", root); return false; diff --git a/src/util/DomainUtils.cpp b/src/util/DomainUtils.cpp new file mode 100644 index 00000000..f79cf1e1 --- /dev/null +++ b/src/util/DomainUtils.cpp @@ -0,0 +1,37 @@ +#include "util/DomainUtils.h" + +std::string_view OpenShock::DomainUtils::GetDomainFromUrl(std::string_view url) { + if (url.empty()) { + return {}; + } + + // Remove the protocol eg. "https://api.example.com:443/path" -> "api.example.com:443/path" + auto separator = url.find("://"); + if (separator != std::string_view::npos) { + url.substr(separator + 3); + } + + // Remove the path eg. "api.example.com:443/path" -> "api.example.com:443" + separator = url.find('/'); + if (separator != std::string_view::npos) { + url = url.substr(0, separator); + } + + // Remove the port eg. "api.example.com:443" -> "api.example.com" + separator = url.rfind(':'); + if (separator != std::string_view::npos) { + url = url.substr(0, separator); + } + + // Remove all subdomains eg. "api.example.com" -> "example.com" + separator = url.rfind('.'); + if (separator == std::string_view::npos) { + return url; // E.g. "localhost" + } + separator = url.rfind('.', separator - 1); + if (separator != std::string_view::npos) { + url = url.substr(separator + 1); + } + + return url; +} diff --git a/src/util/ParitionUtils.cpp b/src/util/ParitionUtils.cpp index 67a83229..b6d43f41 100644 --- a/src/util/ParitionUtils.cpp +++ b/src/util/ParitionUtils.cpp @@ -4,11 +4,12 @@ const char* const TAG = "PartitionUtils"; #include "Core.h" #include "Hashing.h" -#include "http/HTTPRequestManager.h" +#include "http/HTTPClient.h" #include "Logging.h" #include "util/HexUtils.h" -bool OpenShock::TryGetPartitionHash(const esp_partition_t* partition, char (&hash)[65]) { +bool OpenShock::TryGetPartitionHash(const esp_partition_t* partition, char (&hash)[65]) +{ uint8_t buffer[32]; esp_err_t err = esp_partition_get_sha256(partition, buffer); if (err != ESP_OK) { @@ -22,7 +23,8 @@ bool OpenShock::TryGetPartitionHash(const esp_partition_t* partition, char (&has return true; } -bool OpenShock::FlashPartitionFromUrl(const esp_partition_t* partition, std::string_view remoteUrl, const uint8_t (&remoteHash)[32], std::function progressCallback) { +bool OpenShock::FlashPartitionFromUrl(const esp_partition_t* partition, const char* remoteUrl, const uint8_t (&remoteHash)[32], std::function progressCallback) +{ OpenShock::SHA256 sha256; if (!sha256.begin()) { OS_LOGE(TAG, "Failed to initialize SHA256 hash"); @@ -31,27 +33,8 @@ bool OpenShock::FlashPartitionFromUrl(const esp_partition_t* partition, std::str std::size_t contentLength = 0; std::size_t contentWritten = 0; - int64_t lastProgress = 0; + int64_t lastProgress = 0; - auto sizeValidator = [partition, &contentLength, progressCallback, &lastProgress](std::size_t size) -> bool { - if (size > partition->size) { - OS_LOGE(TAG, "Remote partition binary is too large"); - return false; - } - - // Erase app partition. - if (esp_partition_erase_range(partition, 0, partition->size) != ESP_OK) { - OS_LOGE(TAG, "Failed to erase partition in preparation for update"); - return false; - } - - contentLength = size; - - lastProgress = OpenShock::millis(); - progressCallback(0, contentLength, 0.0f); - - return true; - }; auto dataWriter = [partition, &sha256, &contentLength, &contentWritten, progressCallback, &lastProgress](std::size_t offset, const uint8_t* data, std::size_t length) -> bool { if (esp_partition_write(partition, offset, data, length) != ESP_OK) { OS_LOGE(TAG, "Failed to write to partition"); @@ -75,23 +58,37 @@ bool OpenShock::FlashPartitionFromUrl(const esp_partition_t* partition, std::str }; // Start streaming binary to app partition. - auto appBinaryResponse = OpenShock::HTTP::Download( - remoteUrl, - { - {"Accept", "application/octet-stream"} - }, - sizeValidator, - dataWriter, - std::array {200, 304}, - 180'000 - ); // 3 minutes - if (appBinaryResponse.result != OpenShock::HTTP::RequestResult::Success) { - OS_LOGE(TAG, "Failed to download remote partition binary: [%u]", appBinaryResponse.code); + HTTP::HTTPClient client(remoteUrl, 180'000); // 3 minutes timeout + auto response = client.Get(); + if (!response.Ok() || (response.StatusCode() != 200 && response.StatusCode() != 304)) { + OS_LOGE(TAG, "Failed to download remote partition binary: [%u]", response.StatusCode()); + return false; + } + + if (response.ContentLength() > partition->size) { + OS_LOGE(TAG, "Remote partition binary is too large"); + return false; + } + + // Erase app partition. + if (esp_partition_erase_range(partition, 0, partition->size) != ESP_OK) { + OS_LOGE(TAG, "Failed to erase partition in preparation for update"); + return false; + } + + contentLength = response.ContentLength(); + + lastProgress = OpenShock::millis(); + progressCallback(0, contentLength, 0.0f); + + auto streamResult = response.ReadStream(dataWriter); + if (streamResult.error != HTTP::HTTPError::None) { + OS_LOGE(TAG, "Failed to download partition: %s", HTTP::HTTPErrorToString(streamResult.error)); return false; } progressCallback(contentLength, contentLength, 1.0f); - OS_LOGD(TAG, "Wrote %u bytes to partition", appBinaryResponse.data); + OS_LOGD(TAG, "Wrote %u bytes to partition", contentLength); std::array localHash; if (!sha256.finish(localHash)) {