5#include <spdlog/spdlog.h>
17 asio::any_io_executor executor,
18 std::unique_ptr<transport::Transport> transport,
19 std::shared_ptr<spdlog::logger> logger)
20 : logger_(logger ? logger : spdlog::default_logger()),
21 executor_(std::move(executor)),
22 transport_(std::move(transport)),
23 dispatcher_(executor_, logger_),
24 endpoint_strand_(asio::make_strand(executor_)) {
28 asio::any_io_executor executor,
29 std::unique_ptr<transport::Transport> transport)
30 -> asio::awaitable<std::expected<std::unique_ptr<RpcEndpoint>,
RpcError>> {
31 auto endpoint = std::make_unique<RpcEndpoint>(executor, std::move(transport));
33 auto start_result =
co_await endpoint->Start();
35 co_return std::unexpected(start_result.error());
38 endpoint->Logger()->debug(
"Client endpoint initialized");
43 Logger()->debug(
"RpcEndpoint starting");
44 if (is_running_.exchange(
true)) {
46 RpcErrorCode::kClientError,
"RPC endpoint is already running");
49 pending_requests_.clear();
52 auto start_result =
co_await transport_->Start();
54 co_return start_result;
58 StartMessageProcessing();
61 co_return std::expected<void, RpcError>{};
65 -> asio::awaitable<std::expected<void, RpcError>> {
68 co_return std::expected<void, RpcError>{};
71 if (message_loop_.valid()) {
72 co_await std::move(message_loop_);
75 co_return std::expected<void, RpcError>{};
79 if (!is_running_.exchange(
false)) {
80 co_return std::expected<void, RpcError>{};
83 co_await asio::post(endpoint_strand_, asio::use_awaitable);
84 cancel_signal_.emit(asio::cancellation_type::all);
86 Logger()->debug(
"Shutting down RPC endpoint");
91 for (
auto &[
id, request] : pending_requests_) {
92 request->Cancel(-32603,
"RPC endpoint shutting down");
94 pending_requests_.clear();
97 if (message_loop_.valid()) {
98 co_await std::move(message_loop_);
102 auto close_result =
co_await transport_->Close();
104 co_return close_result;
111 std::string method, std::optional<nlohmann::json> params)
112 -> asio::awaitable<std::expected<nlohmann::json, RpcError>> {
115 RpcErrorCode::kClientError,
"RPC endpoint is not running");
118 auto request_id = GetNextRequestId();
119 Request request(method, std::move(params), request_id);
120 std::string message = request.
ToJson().dump();
122 Logger()->debug(
"RpcEndpoint sending message: {}", message.substr(0, 70));
123 auto pending_request = std::make_shared<PendingRequest>(endpoint_strand_);
124 asio::post(endpoint_strand_, [
this, request_id, pending_request] {
125 pending_requests_[request_id] = pending_request;
128 auto send_result =
co_await transport_->SendMessage(message);
130 co_return std::unexpected(send_result.error());
133 auto result =
co_await pending_request->GetResult();
134 if (result.contains(
"error")) {
135 auto err = result[
"error"];
137 RpcErrorCode::kClientError, err[
"message"].get<std::string>());
140 co_return result[
"result"];
144 std::string method, std::optional<nlohmann::json> params)
145 -> asio::awaitable<std::expected<void, RpcError>> {
146 Logger()->debug(
"RpcEndpoint sending notification: {}", method);
149 RpcErrorCode::kClientError,
"RpcEndpoint is not running");
152 Request request(method, std::move(params));
153 std::string message = request.
ToJson().dump();
155 Logger()->debug(
"RpcEndpoint sending message: {}", message.substr(0, 70));
156 auto send_result =
co_await transport_->SendMessage(message);
158 co_return send_result;
161 co_return std::expected<void, RpcError>{};
176 return !pending_requests_.empty();
179void RpcEndpoint::StartMessageProcessing() {
180 Logger()->debug(
"RpcEndpoint starting message processing");
181 message_loop_ = asio::co_spawn(
185 "RpcEndpoint starting message processing, is_running_: {}",
187 return this->ProcessMessagesLoop(cancel_signal_.slot());
189 asio::use_awaitable);
193auto RetryDelay(asio::any_io_executor exec) -> asio::awaitable<void> {
194 asio::steady_timer timer(exec, std::chrono::milliseconds(100));
195 co_await timer.async_wait(asio::use_awaitable);
199auto RpcEndpoint::ProcessMessagesLoop(asio::cancellation_slot slot)
200 -> asio::awaitable<void> {
201 auto state =
co_await asio::this_coro::cancellation_state;
203 while (is_running_ && !state.cancelled()) {
204 auto message_result =
co_await transport_->ReceiveMessage();
205 if (!message_result) {
206 Logger()->error(
"Receive error: {}", message_result.error().Message());
207 co_await RetryDelay(executor_);
211 auto handle_result =
co_await HandleMessage(*message_result);
212 if (!handle_result) {
213 Logger()->error(
"Handle error: {}", handle_result.error().Message());
214 co_await RetryDelay(executor_);
221auto IsResponse(
const nlohmann::json &msg) ->
bool {
222 return msg.contains(
"id") &&
223 (msg.contains(
"result") || msg.contains(
"error"));
227auto RpcEndpoint::HandleMessage(std::string message)
228 -> asio::awaitable<std::expected<void, RpcError>> {
229 Logger()->debug(
"RpcEndpoint handling message: {}", message.substr(0, 70));
230 const auto json_message_result =
231 nlohmann::json::parse(message,
nullptr,
false);
232 if (json_message_result.is_discarded()) {
234 RpcErrorCode::kClientError,
"Failed to parse message");
236 const auto &json_message = json_message_result;
238 if (IsResponse(json_message)) {
240 if (!response.has_value()) {
242 RpcErrorCode::kClientError,
"Invalid response");
244 co_return co_await HandleResponse(std::move(response.value()));
247 if (
auto response =
co_await dispatcher_.DispatchRequest(message)) {
248 co_return co_await transport_->SendMessage(*response);
251 co_return std::expected<void, RpcError>{};
254auto RpcEndpoint::HandleResponse(Response response)
255 -> asio::awaitable<std::expected<void, RpcError>> {
256 auto id_opt = response.GetId();
257 if (!id_opt || !std::holds_alternative<int64_t>(*id_opt)) {
259 RpcErrorCode::kClientError,
"Response ID missing or not int64");
262 const auto id = std::get<int64_t>(*id_opt);
264 std::shared_ptr<PendingRequest> request;
267 asio::post(endpoint_strand_, [
this,
id, &request, &found] {
268 auto it = pending_requests_.find(
id);
269 if (it != pending_requests_.end()) {
270 request = it->second;
271 pending_requests_.erase(it);
276 co_await asio::post(
co_await asio::this_coro::executor, asio::use_awaitable);
278 if (!found || !request) {
280 RpcErrorCode::kClientError,
281 "Unknown request ID: " + std::to_string(
id));
284 request->SetResult(response.ToJson());
285 co_return std::expected<void, RpcError>{};
std::function< asio::awaitable< nlohmann::json >( const std::optional< nlohmann::json > &)> MethodCallHandler
void RegisterNotification(const std::string &method, const NotificationHandler &handler)
std::function< asio::awaitable< void >( const std::optional< nlohmann::json > &)> NotificationHandler
void RegisterMethodCall(const std::string &method, const MethodCallHandler &handler)
auto ToJson() const -> nlohmann::json
static auto FromJson(const nlohmann::json &json) -> std::expected< Response, error::RpcError >
static auto CreateClient(asio::any_io_executor executor, std::unique_ptr< transport::Transport > transport) -> asio::awaitable< std::expected< std::unique_ptr< RpcEndpoint >, RpcError > >
RpcEndpoint(asio::any_io_executor executor, std::unique_ptr< transport::Transport > transport, std::shared_ptr< spdlog::logger > logger=nullptr)
auto HasPendingRequests() const -> bool
void RegisterNotification(std::string method, typename Dispatcher::NotificationHandler handler)
auto Shutdown() -> asio::awaitable< std::expected< void, RpcError > >
void RegisterMethodCall(std::string method, typename Dispatcher::MethodCallHandler handler)
auto WaitForShutdown() -> asio::awaitable< std::expected< void, RpcError > >
auto SendMethodCall(std::string method, std::optional< nlohmann::json > params=std::nullopt) -> asio::awaitable< std::expected< nlohmann::json, RpcError > >
auto Start() -> asio::awaitable< std::expected< void, RpcError > >
auto SendNotification(std::string method, std::optional< nlohmann::json > params=std::nullopt) -> asio::awaitable< std::expected< void, RpcError > >
auto Logger() -> std::shared_ptr< spdlog::logger >
static auto UnexpectedFromCode(RpcErrorCode code, std::string message="") -> std::unexpected< RpcError >
auto Ok() -> std::expected< void, RpcError >