From d547bb9f953b54970a01987d0dfae97976e80364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maurice=20Gr=C3=B6nwoldt?= Date: Sun, 4 Feb 2024 13:52:08 +0100 Subject: [PATCH] Make Router more sexy --- Includes/Http.h | 30 ++++++--- Includes/Response.h | 5 +- Includes/Route.h | 14 ++-- Includes/Router.h | 29 +++++++-- Includes/Server.h | 3 +- Source/Route.cpp | 76 +++------------------- Source/Router.cpp | 153 ++++++++++++++++++++++---------------------- Source/Server.cpp | 5 +- example/main.cpp | 2 +- 9 files changed, 144 insertions(+), 173 deletions(-) diff --git a/Includes/Http.h b/Includes/Http.h index 4169996..73c7383 100644 --- a/Includes/Http.h +++ b/Includes/Http.h @@ -1,4 +1,5 @@ #pragma once +#include namespace VWeb { enum class HttpStatusCode : int { @@ -63,14 +64,25 @@ enum class HttpStatusCode : int { NotExtended = 510, NetworkAuthenticationRequired = 511 }; -enum class HttpMethod { - GET = 0, - HEAD, - OPTIONS, - POST, - PUT, - PATCH, - DELETE, - FALLBACK + +enum class HttpMethod : uint32_t { + GET = 1 << 0, + HEAD = 1 << 1, + OPTIONS = 1 << 2, + POST = 1 << 3, + PUT = 1 << 4, + PATCH = 1 << 5, + DELETE = 1 << 6, + FALLBACK = 1 << 7, + ALL = std::numeric_limits::max() }; + +inline uint32_t operator|(HttpMethod lhs, HttpMethod rhs) { + return static_cast(lhs) | static_cast(rhs); +} + +inline bool operator&(const uint32_t lhs, HttpMethod rhs) { + return lhs & static_cast(rhs); +} + } // namespace VWeb \ No newline at end of file diff --git a/Includes/Response.h b/Includes/Response.h index 983f3d9..296ced2 100644 --- a/Includes/Response.h +++ b/Includes/Response.h @@ -2,9 +2,9 @@ #include "Cookie.h" #include "Http.h" +#include "ParameterValue.h" #include "Session.h" #include "Types.h" -#include "ParameterValue.h" namespace VWeb { class Response { @@ -22,7 +22,8 @@ public: std::string GetResponse(); void SetHeader(const std::string &key, ParameterValue &value); void SetHeader(const std::string &key, const std::string &value); - void AddHeaders(const std::string &key, const std::vector &values); + void AddHeaders(const std::string &key, + const std::vector &values); void AddHeader(const std::string &key, const std::string &value); void SetType(const std::string &type); void AddContent(const std::string &data); diff --git a/Includes/Route.h b/Includes/Route.h index 5f9aa9d..cb41296 100644 --- a/Includes/Route.h +++ b/Includes/Route.h @@ -1,9 +1,10 @@ #pragma once #include "Http.h" -#include #include "Request.h" #include "Response.h" + +#include #include namespace VWeb { @@ -12,23 +13,20 @@ class Route { public: Route() = default; virtual ~Route() = default; - Route(std::initializer_list); virtual bool Execute(Request &request, Response &response); virtual bool Get(Request &request, Response &response); virtual bool Post(Request &request, Response &response); virtual bool Put(Request &request, Response &response); virtual bool Patch(Request &request, Response &response); virtual bool Delete(Request &request, Response &response); - bool Options(Request &request, Response &response); virtual bool Fallback(Request &request, Response &response); - bool SupportsMethod(Request &request); virtual bool IsAllowed(Request &request); - void AllowMethod(HttpMethod method); + void SetAllowedMethods(uint32_t methods); + [[nodiscard]] uint32_t GetAllowedMethods() const; protected: - bool m_AllowAll{true}; - std::vector m_AllowedMethods; + uint32_t m_AllowedMethods = 0; friend Router; }; -} \ No newline at end of file +} // namespace VWeb \ No newline at end of file diff --git a/Includes/Router.h b/Includes/Router.h index 03088a9..29da6de 100644 --- a/Includes/Router.h +++ b/Includes/Router.h @@ -7,23 +7,44 @@ namespace VWeb { typedef std::function RouteFunction; +typedef std::function()> RouteInstaniateFunction; + class Router { public: Router(); - void AddRoute(const std::string &name, const Ref &route); - Ref &GetRoute(const std::string &name); void DeleteRoute(const std::string &name); Ref HandleRoute(Ref &request); Ref FindRoute(Ref &request); static void AddToArgs(Ref &request, std::vector &items); -public: + template + void Register(const std::string &endpoint, HttpMethod allowedMethod) { + Register(endpoint, static_cast(allowedMethod)); + } + + template + void + Register(const std::string &endpoint, + uint32_t allowedMethods = static_cast(HttpMethod::ALL)) { + static_assert(std::is_base_of_v, "must be a Route"); + allowedMethods |= HttpMethod::HEAD | HttpMethod::OPTIONS; + m_Routes[endpoint] = {.AllowedMethods = allowedMethods, + .Instaniate = [] { return std::make_shared(); }}; + } + void Get(const std::string &path, RouteFunction); void Post(const std::string &path, RouteFunction); void Put(const std::string &path, RouteFunction); void Patch(const std::string &path, RouteFunction); void Delete(const std::string &path, RouteFunction); - std::unordered_map> m_Routes; + +private: + struct RouteInstance { + uint32_t AllowedMethods = HttpMethod::OPTIONS | HttpMethod::HEAD; + RouteInstaniateFunction Instaniate; + }; + std::unordered_map m_Routes; + std::unordered_map m_FunctionRoutes; }; } // namespace VWeb \ No newline at end of file diff --git a/Includes/Server.h b/Includes/Server.h index 52675dd..e8014a0 100644 --- a/Includes/Server.h +++ b/Includes/Server.h @@ -21,8 +21,7 @@ public: void Stop() { m_IsExit = true; } Ref &GetRouter() { return m_Router; } Ref &GetServerConfig() { return m_ServerConfig; } - void AddRoute(const std::string &path, const Ref &route); - void RemoveRoute(const std::string &path); + void RemoveRoute(const std::string &path) const; Ref &Middleware(); diff --git a/Source/Route.cpp b/Source/Route.cpp index c826ca2..d6e793c 100644 --- a/Source/Route.cpp +++ b/Source/Route.cpp @@ -4,84 +4,26 @@ namespace VWeb { -#define stringify(name) {name, std::string(#name).replace(0,12,"")} -static std::unordered_map s_HttpMethodToString = { - stringify(HttpMethod::HEAD), - stringify(HttpMethod::GET), - stringify(HttpMethod::OPTIONS), - stringify(HttpMethod::POST), - stringify(HttpMethod::PUT), - stringify(HttpMethod::PATCH), - stringify(HttpMethod::DELETE), - stringify(HttpMethod::FALLBACK) -}; -#undef stringify - -Route::Route(std::initializer_list methods) { - m_AllowedMethods = methods; - m_AllowAll = false; - m_AllowedMethods.push_back(HttpMethod::HEAD); - m_AllowedMethods.push_back(HttpMethod::OPTIONS); -} bool Route::Execute(Request &request, Response &response) { switch (request.Method) { case HttpMethod::GET: case HttpMethod::HEAD: return Get(request, response); case HttpMethod::POST: return Post(request, response); case HttpMethod::PUT: return Put(request, response); - case HttpMethod::OPTIONS: return Options(request, response); case HttpMethod::PATCH: return Patch(request, response); case HttpMethod::DELETE: return Delete(request, response); default: return Fallback(request, response); } } -bool Route::Get(Request &request, Response &response) { - return true; -} -bool Route::Post(Request &request, Response &response) { - return true; -} -bool Route::Put(Request &request, Response &response) { - return true; -} -bool Route::Patch(Request &request, Response &response) { - return true; -} -bool Route::Delete(Request &request, Response &response) { - return true; -} -bool Route::Options(Request &request, Response &response) { - std::stringstream str{}; - bool isFirst = true; - if (m_AllowAll) { - for (auto &[key, value] : s_HttpMethodToString) { - if (!isFirst) - str << ", "; - str << value; - isFirst = false; - } - } else { - for (auto &method : m_AllowedMethods) { - if (!isFirst) - str << ", "; - str << s_HttpMethodToString[method]; - isFirst = false; - } - } - response.SetHeader("Allow", str.str()); - return true; -} -bool Route::Fallback(Request &request, Response &response) { - return true; -} +bool Route::Get(Request &request, Response &response) { return true; } +bool Route::Post(Request &request, Response &response) { return true; } +bool Route::Put(Request &request, Response &response) { return true; } +bool Route::Patch(Request &request, Response &response) { return true; } +bool Route::Delete(Request &request, Response &response) { return true; } +bool Route::Fallback(Request &request, Response &response) { return true; } bool Route::IsAllowed(Request &request) { return true; } -bool Route::SupportsMethod(Request &request) { - return m_AllowAll || - std::find(m_AllowedMethods.begin(), m_AllowedMethods.end(), - request.Method) != m_AllowedMethods.end(); -} -void Route::AllowMethod(HttpMethod method) { - if (std::find(m_AllowedMethods.begin(), m_AllowedMethods.end(), method) == m_AllowedMethods.end()) - m_AllowedMethods.push_back(method); +void Route::SetAllowedMethods(const uint32_t methods) { + m_AllowedMethods = methods; } +uint32_t Route::GetAllowedMethods() const { return m_AllowedMethods; } } // namespace VWeb diff --git a/Source/Router.cpp b/Source/Router.cpp index 8fba206..b08d8f9 100644 --- a/Source/Router.cpp +++ b/Source/Router.cpp @@ -6,6 +6,15 @@ namespace VWeb { +#define stringify(name) \ + { name, std::string(#name).replace(0, 12, "") } +static std::unordered_map s_HttpMethodToString = { + stringify(HttpMethod::HEAD), stringify(HttpMethod::GET), + stringify(HttpMethod::OPTIONS), stringify(HttpMethod::POST), + stringify(HttpMethod::PUT), stringify(HttpMethod::PATCH), + stringify(HttpMethod::DELETE), stringify(HttpMethod::FALLBACK)}; +#undef stringify + template constexpr auto to_underlying(E e) noexcept { return static_cast>(e); } @@ -21,40 +30,7 @@ public: } }; -class InstanceHandleRoute : public Route { -public: - bool Get(Request &request, Response &response) override { - return GetFunc && GetFunc(request, response); - } - bool Post(Request &request, Response &response) override { - return PostFunc && PostFunc(request, response); - } - bool Put(Request &request, Response &response) override { - return PutFunc && PutFunc(request, response); - } - bool Patch(Request &request, Response &response) override { - return PatchFunc && PatchFunc(request, response); - } - bool Delete(Request &request, Response &response) override { - return DeleteFunc && DeleteFunc(request, response); - } - -protected: - RouteFunction GetFunc{nullptr}; - RouteFunction PostFunc{nullptr}; - RouteFunction PutFunc{nullptr}; - RouteFunction PatchFunc{nullptr}; - RouteFunction DeleteFunc{nullptr}; - friend Router; -}; - -Router::Router() { m_Routes["@"] = CreateRef(); } - -void Router::AddRoute(const std::string &name, const Ref &route) { - m_Routes[name] = route; -} - -Ref &Router::GetRoute(const std::string &name) { return m_Routes[name]; } +Router::Router() { Register("@"); } void Router::DeleteRoute(const std::string &name) { if (m_Routes.contains(name)) { @@ -62,59 +38,97 @@ void Router::DeleteRoute(const std::string &name) { } } +static void HandleOptions(Ref &response, uint32_t allowedMethods) { + std::stringstream str{}; + bool isFirst = true; + for (auto &[key, value] : s_HttpMethodToString) { + if (allowedMethods & static_cast(key)) { + if (!isFirst) + str << ", "; + str << value; + isFirst = false; + } + } + response->SetHeader("Allow", str.str()); +} + Ref Router::HandleRoute(Ref &request) { auto response = CreateRef(); auto route = FindRoute(request); response->CookieData = request->CookieData; response->SessionData = request->SessionData; response->Method = request->Method; + if (!route) { + // Lets check if we can run it through functions routes.. + const auto it = m_FunctionRoutes.find( + s_HttpMethodToString[request->Method] + request->URI); + if (it != m_FunctionRoutes.end()) { + it->second(*request, *response); + return response; + } response->SetStatus(HttpStatusCode::NotFound); - m_Routes["@"]->Execute(*request, *response); + m_Routes["@"].Instaniate()->Execute(*request, *response); return response; } if (!route->IsAllowed(*request)) { response->SetStatus(HttpStatusCode::Forbidden); - m_Routes["@"]->Execute(*request, *response); + m_Routes["@"].Instaniate()->Execute(*request, *response); + return response; + } + + if (request->Method == HttpMethod::OPTIONS) { + HandleOptions(response, route->GetAllowedMethods()); return response; } if (!route->Execute(*request, *response)) { std::string rKey = "@" + std::to_string(to_underlying(response->Status)); - m_Routes.contains(rKey) ? m_Routes[rKey]->Execute(*request, *response) - : m_Routes["@"]->Execute(*request, *response); + m_Routes.contains(rKey) + ? m_Routes[rKey].Instaniate()->Execute(*request, *response) + : m_Routes["@"].Instaniate()->Execute(*request, *response); } return response; } +static Ref Instaniate(const RouteInstaniateFunction &func, + uint32_t allowedMethods) { + auto ref = func(); + ref->SetAllowedMethods(allowedMethods); + return ref; +} + Ref Router::FindRoute(Ref &request) { - auto &url = request->URI; - if (m_Routes.contains(url.data())) { - auto &route = m_Routes.at(url.data()); - if (route->SupportsMethod(*request)) - return route; - } + const auto &url = request->URI; + if (url.starts_with("@")) return nullptr; - auto split = String::Split(url.data(), "/"); + { + if (const auto it = m_Routes.find(url); + it != m_Routes.end() && it->second.AllowedMethods & request->Method) { + return Instaniate(it->second.Instaniate, it->second.AllowedMethods); + } + } + + auto split = String::Split(url, "/"); if (split.size() > 1) { AddToArgs(request, split); while (split.size() > 1) { std::string nUrl = String::Join(split, "/"); - if (m_Routes.contains(nUrl)) { - auto &route = m_Routes[nUrl]; - if (route->SupportsMethod(*request)) - return route; + if (auto it = m_Routes.find(url); + it != m_Routes.end() && it->second.AllowedMethods & request->Method) { + return Instaniate(it->second.Instaniate, it->second.AllowedMethods); } AddToArgs(request, split); } } - if (m_Routes.contains("/")) { - auto &route = m_Routes["/"]; - if (route->SupportsMethod(*request)) - return route; + { + if (const auto it = m_Routes.find("/"); + it != m_Routes.end() && it->second.AllowedMethods & request->Method) { + return Instaniate(it->second.Instaniate, it->second.AllowedMethods); + } } return nullptr; } @@ -124,37 +138,24 @@ void Router::AddToArgs(Ref &request, std::vector &items) { items.pop_back(); } -InstanceHandleRoute* GetOrNull(const std::string& path, Router* router) { - InstanceHandleRoute* route; - if (!router->m_Routes.contains(path)) { - router->AddRoute(path, CreateRef()); - } - route = dynamic_cast(router->m_Routes[path].get()); - return route; -} void Router::Get(const std::string &path, RouteFunction func) { - auto* route = GetOrNull(path, this); - if (route) - route->GetFunc = std::move(func); + m_FunctionRoutes[s_HttpMethodToString[HttpMethod::GET] + path] = + std::move(func); } void Router::Post(const std::string &path, RouteFunction func) { - auto* route = GetOrNull(path, this); - if (route) - route->PostFunc = std::move(func); + m_FunctionRoutes[s_HttpMethodToString[HttpMethod::POST] + path] = + std::move(func); } void Router::Put(const std::string &path, RouteFunction func) { - auto* route = GetOrNull(path, this); - if (route) - route->PutFunc = std::move(func); + m_FunctionRoutes[s_HttpMethodToString[HttpMethod::PUT] + path] = + std::move(func); } void Router::Patch(const std::string &path, RouteFunction func) { - auto* route = GetOrNull(path, this); - if (route) - route->PatchFunc = std::move(func); + m_FunctionRoutes[s_HttpMethodToString[HttpMethod::PATCH] + path] = + std::move(func); } void Router::Delete(const std::string &path, RouteFunction func) { - auto* route = GetOrNull(path, this); - if (route) - route->DeleteFunc = std::move(func); + m_FunctionRoutes[s_HttpMethodToString[HttpMethod::DELETE] + path] = + std::move(func); } } // namespace VWeb diff --git a/Source/Server.cpp b/Source/Server.cpp index 6f0bf2e..177eaaa 100644 --- a/Source/Server.cpp +++ b/Source/Server.cpp @@ -29,10 +29,7 @@ void Server::Start() { fprintf(stdout, "[VWeb] Running Server On: 0.0.0.0:%d\n", m_ServerConfig->Port); } -void Server::AddRoute(const std::string &path, const Ref &route) { - m_Router->AddRoute(path, route); -} -void Server::RemoveRoute(const std::string &path) { +void Server::RemoveRoute(const std::string &path) const { m_Router->DeleteRoute(path); } void Server::Execute() { diff --git a/example/main.cpp b/example/main.cpp index c39a43c..80bfeb2 100644 --- a/example/main.cpp +++ b/example/main.cpp @@ -25,7 +25,7 @@ int main() { auto& router = server.GetRouter(); // For debugging and profiling more than 1 thread can be hard. server.GetServerConfig()->WorkerThreads = 1; - router->Get("/test", [](const Request&, Response& response) { + router->Get("/test", [](Request&, Response& response) { response << "NICE"; return true; });