diff --git a/Includes/Request.h b/Includes/Request.h index 7815052..f0f5111 100644 --- a/Includes/Request.h +++ b/Includes/Request.h @@ -29,6 +29,6 @@ public: ParameterValue &Header(const std::string &key) { return Headers[key]; } std::unordered_map Parameters; std::unordered_map Headers; - std::vector URLParameters; + std::unordered_map URLParameters; }; } // namespace VWeb \ No newline at end of file diff --git a/Includes/Router.h b/Includes/Router.h index 29da6de..cbc53d8 100644 --- a/Includes/Router.h +++ b/Includes/Router.h @@ -3,20 +3,38 @@ #include "Route.h" #include -#include +#include namespace VWeb { typedef std::function RouteFunction; typedef std::function()> RouteInstaniateFunction; +struct RouteTree { + void Add(const std::string &path, uint32_t allowedMethods, + RouteInstaniateFunction instaniate); + Ref Find(const std::string &path, Request &request); + +protected: + struct Node { + explicit Node(const uint64_t id) : ID(id) {} + std::unordered_map> Children{}; + uint64_t ID{0}; + }; + Node Root{0}; + uint64_t m_NodeID = 1; + struct RouteInstance { + uint32_t AllowedMethods = HttpMethod::OPTIONS | HttpMethod::HEAD; + RouteInstaniateFunction Instaniate; + }; + std::unordered_map m_Routes; +}; + class Router { public: Router(); - void DeleteRoute(const std::string &name); Ref HandleRoute(Ref &request); Ref FindRoute(Ref &request); - static void AddToArgs(Ref &request, std::vector &items); template void Register(const std::string &endpoint, HttpMethod allowedMethod) { @@ -28,23 +46,18 @@ public: Register(const std::string &endpoint, uint32_t allowedMethods = static_cast(HttpMethod::ALL)) { static_assert(std::is_base_of_v, "must be a Route"); - allowedMethods |= HttpMethod::HEAD | HttpMethod::OPTIONS; - m_Routes[endpoint] = {.AllowedMethods = allowedMethods, - .Instaniate = [] { return std::make_shared(); }}; + m_Tree.Add(endpoint, + allowedMethods | HttpMethod::HEAD | HttpMethod::OPTIONS, + [] { return std::make_shared(); }); } - void Get(const std::string &path, RouteFunction); - void Post(const std::string &path, RouteFunction); - void Put(const std::string &path, RouteFunction); - void Patch(const std::string &path, RouteFunction); - void Delete(const std::string &path, RouteFunction); + void Get(const std::string &path, const RouteFunction &); + void Post(const std::string &path, const RouteFunction &); + void Put(const std::string &path, const RouteFunction &); + void Patch(const std::string &path, const RouteFunction &); + void Delete(const std::string &path, const RouteFunction &); private: - struct RouteInstance { - uint32_t AllowedMethods = HttpMethod::OPTIONS | HttpMethod::HEAD; - RouteInstaniateFunction Instaniate; - }; - std::unordered_map m_Routes; - std::unordered_map m_FunctionRoutes; + RouteTree m_Tree{}; }; } // namespace VWeb \ No newline at end of file diff --git a/Includes/Server.h b/Includes/Server.h index e8014a0..605dd19 100644 --- a/Includes/Server.h +++ b/Includes/Server.h @@ -21,7 +21,6 @@ public: void Stop() { m_IsExit = true; } Ref &GetRouter() { return m_Router; } Ref &GetServerConfig() { return m_ServerConfig; } - void RemoveRoute(const std::string &path) const; Ref &Middleware(); diff --git a/Source/RequestHandler.cpp b/Source/RequestHandler.cpp index c76436b..9cb5f60 100644 --- a/Source/RequestHandler.cpp +++ b/Source/RequestHandler.cpp @@ -49,10 +49,9 @@ void ParseParameterString(Request &req, const std::string &toParse) { } } -std::string GetPostBody(const std::string& originalBody) -{ +std::string GetPostBody(const std::string &originalBody) { auto body = String::Split(originalBody, "\r\n\r\n", 1); - if (body.size() > 1 && ! body[body.size() - 1].empty()) + if (body.size() > 1 && !body[body.size() - 1].empty()) return String::TrimCopy(String::UrlDecode(body[body.size() - 1])); return {}; } @@ -62,7 +61,7 @@ void ParseParameters(Request &request, RequestHandler &requestHandler) { size_t hasURLParameters = uri.find('?'); if (hasURLParameters != std::string::npos) { ParseParameterString(request, uri.substr(hasURLParameters + 1)); - request.URI = uri.substr (0, hasURLParameters); + request.URI = uri.substr(0, hasURLParameters); } if (request.Method == HttpMethod::HEAD || request.Method == HttpMethod::GET || diff --git a/Source/Router.cpp b/Source/Router.cpp index b08d8f9..08ab579 100644 --- a/Source/Router.cpp +++ b/Source/Router.cpp @@ -1,6 +1,7 @@ #include "Includes/VWeb.h" #include "StringUtils.h" +#include #include #include @@ -30,14 +31,59 @@ public: } }; -Router::Router() { Register("@"); } - -void Router::DeleteRoute(const std::string &name) { - if (m_Routes.contains(name)) { - m_Routes.erase(name); +void RouteTree::Add(const std::string &path, uint32_t allowedMethods, + RouteInstaniateFunction instaniate) { + auto segments = String::Split(path, "/"); + auto node = &Root; + for (const auto &segment : segments) { + if (segment.empty()) + continue; + if (!node->Children.contains(segment)) { + node->Children[segment] = std::make_unique(m_NodeID++); + } + node = node->Children.at(segment).get(); } + m_Routes[node->ID] = {.AllowedMethods = allowedMethods, + .Instaniate = std::move(instaniate)}; } +Ref RouteTree::Find(const std::string &path, Request &request) { + auto segments = String::Split(path, "/"); + auto node = &Root; + for (const auto &segment : segments) { + if (segment.empty()) + continue; + if (auto it = node->Children.find(segment); it != node->Children.end()) { + node = it->second.get(); + } else { + // Arguments... + bool foundParameter = false; + for (auto &[key, child] : node->Children) { + if (key[0] == ':') { + node = child.get(); + foundParameter = true; + request.URLParameters[key.substr(1)] = segment; + break; + } + } + if (!foundParameter) { + request.URLParameters = {}; + return nullptr; + } + } + } + if (m_Routes.contains(node->ID)) { + const auto &instance = m_Routes[node->ID]; + auto ref = instance.Instaniate(); + ref->SetAllowedMethods(instance.AllowedMethods); + return ref; + } + request.URLParameters = {}; + return nullptr; +} + +Router::Router() { Register("@"); } + static void HandleOptions(Ref &response, uint32_t allowedMethods) { std::stringstream str{}; bool isFirst = true; @@ -60,21 +106,18 @@ Ref Router::HandleRoute(Ref &request) { response->Method = request->Method; if (!route) { - // Lets check if we can run it through functions routes.. - const auto it = m_FunctionRoutes.find( - s_HttpMethodToString[request->Method] + request->URI); - if (it != m_FunctionRoutes.end()) { - it->second(*request, *response); + route = m_Tree.Find(s_HttpMethodToString[request->Method] + request->URI, + *request); + if (!route) { + response->SetStatus(HttpStatusCode::NotFound); + m_Tree.Find("@", *request)->Execute(*request, *response); return response; } - response->SetStatus(HttpStatusCode::NotFound); - m_Routes["@"].Instaniate()->Execute(*request, *response); - return response; } if (!route->IsAllowed(*request)) { response->SetStatus(HttpStatusCode::Forbidden); - m_Routes["@"].Instaniate()->Execute(*request, *response); + m_Tree.Find("@", *request)->Execute(*request, *response); return response; } @@ -85,77 +128,56 @@ Ref Router::HandleRoute(Ref &request) { if (!route->Execute(*request, *response)) { std::string rKey = "@" + std::to_string(to_underlying(response->Status)); - m_Routes.contains(rKey) - ? m_Routes[rKey].Instaniate()->Execute(*request, *response) - : m_Routes["@"].Instaniate()->Execute(*request, *response); + auto r = m_Tree.Find(rKey, *request); + if (r) { + r->Execute(*request, *response); + } else { + m_Tree.Find("@", *request)->Execute(*request, *response); + } } return response; } -static Ref Instaniate(const RouteInstaniateFunction &func, - uint32_t allowedMethods) { - auto ref = func(); - ref->SetAllowedMethods(allowedMethods); - return ref; -} - Ref Router::FindRoute(Ref &request) { const auto &url = request->URI; - if (url.starts_with("@")) return nullptr; + return m_Tree.Find(url, *request); +} - { - if (const auto it = m_Routes.find(url); - it != m_Routes.end() && it->second.AllowedMethods & request->Method) { - return Instaniate(it->second.Instaniate, it->second.AllowedMethods); - } +struct InlineRoute : Route { + explicit InlineRoute(RouteFunction function) : Func(std::move(function)) {} + RouteFunction Func; + bool Execute(Request &request, Response &response) override { + Func(request, response); + return true; } + bool IsAllowed(Request &request) override { return true; } +}; - auto split = String::Split(url, "/"); - if (split.size() > 1) { - AddToArgs(request, split); - while (split.size() > 1) { - std::string nUrl = String::Join(split, "/"); - if (auto it = m_Routes.find(url); - it != m_Routes.end() && it->second.AllowedMethods & request->Method) { - return Instaniate(it->second.Instaniate, it->second.AllowedMethods); - } - AddToArgs(request, split); - } - } - { - if (const auto it = m_Routes.find("/"); - it != m_Routes.end() && it->second.AllowedMethods & request->Method) { - return Instaniate(it->second.Instaniate, it->second.AllowedMethods); - } - } - return nullptr; +void Router::Get(const std::string &path, const RouteFunction &func) { + m_Tree.Add(s_HttpMethodToString[HttpMethod::GET] + path, + (uint32_t)HttpMethod::GET, + [func] { return std::make_shared(func); }); } - -void Router::AddToArgs(Ref &request, std::vector &items) { - request->URLParameters.push_back(items[items.size() - 1]); - items.pop_back(); +void Router::Post(const std::string &path, const RouteFunction &func) { + m_Tree.Add(s_HttpMethodToString[HttpMethod::POST] + path, + (uint32_t)HttpMethod::POST, + [func] { return std::make_shared(func); }); } - -void Router::Get(const std::string &path, RouteFunction func) { - m_FunctionRoutes[s_HttpMethodToString[HttpMethod::GET] + path] = - std::move(func); +void Router::Put(const std::string &path, const RouteFunction &func) { + m_Tree.Add(s_HttpMethodToString[HttpMethod::PUT] + path, + (uint32_t)HttpMethod::PUT, + [func] { return std::make_shared(func); }); } -void Router::Post(const std::string &path, RouteFunction func) { - m_FunctionRoutes[s_HttpMethodToString[HttpMethod::POST] + path] = - std::move(func); +void Router::Patch(const std::string &path, const RouteFunction &func) { + m_Tree.Add(s_HttpMethodToString[HttpMethod::PATCH] + path, + (uint32_t)HttpMethod::PATCH, + [func] { return std::make_shared(func); }); } -void Router::Put(const std::string &path, RouteFunction func) { - m_FunctionRoutes[s_HttpMethodToString[HttpMethod::PUT] + path] = - std::move(func); -} -void Router::Patch(const std::string &path, RouteFunction func) { - m_FunctionRoutes[s_HttpMethodToString[HttpMethod::PATCH] + path] = - std::move(func); -} -void Router::Delete(const std::string &path, RouteFunction func) { - m_FunctionRoutes[s_HttpMethodToString[HttpMethod::DELETE] + path] = - std::move(func); +void Router::Delete(const std::string &path, const RouteFunction &func) { + m_Tree.Add(s_HttpMethodToString[HttpMethod::DELETE] + path, + (uint32_t)HttpMethod::DELETE, + [func] { return std::make_shared(func); }); } } // namespace VWeb diff --git a/Source/Server.cpp b/Source/Server.cpp index 177eaaa..725f9fd 100644 --- a/Source/Server.cpp +++ b/Source/Server.cpp @@ -29,9 +29,6 @@ void Server::Start() { fprintf(stdout, "[VWeb] Running Server On: 0.0.0.0:%d\n", m_ServerConfig->Port); } -void Server::RemoveRoute(const std::string &path) const { - m_Router->DeleteRoute(path); -} void Server::Execute() { constexpr size_t MAX_EVENTS = 5000; struct epoll_event events[MAX_EVENTS]; diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index a704415..62ab9a8 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,22 +1,8 @@ cmake_minimum_required(VERSION 3.17) project(VWeb_Example) -set(CMAKE_CXX_STANDARD 20) -set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) find_package(VWeb 1.0 REQUIRED) add_executable(VWeb_Example main.cpp) include_directories(${CMAKE_SOURCE_DIR}/..) - -set(mode Release) -if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(mode Debug) -endif () -set(vweb_lib ${CMAKE_SOURCE_DIR}/../dist/libVWeb.${mode}.a) - -SET_SOURCE_FILES_PROPERTIES( - main.cpp - PROPERTIES OBJECT_DEPENDS ${vweb_lib} -) - -target_link_libraries(VWeb_Example Threads::Threads ${vweb_lib}) +target_link_libraries(VWeb_Example Threads::Threads VWeb) \ No newline at end of file diff --git a/example/main.cpp b/example/main.cpp index 80bfeb2..7741a7d 100644 --- a/example/main.cpp +++ b/example/main.cpp @@ -2,34 +2,35 @@ class MyCompleteController : public VWeb::Route { public: - bool Get(const VWeb::Request&, VWeb::Response& response) { - response << "MyCompleteController: GET"; + bool Get(VWeb::Request &req, VWeb::Response &response) override { + response << "MyCompleteController: GET:: \r\n\r\nParameters:\r\n\r\n"; + for (auto &[key, value] : req.URLParameters) { + response << key << ": " << value << "\r\n"; + } return true; } - bool Post(const VWeb::Request&, VWeb::Response& response) { + bool Post(VWeb::Request &, VWeb::Response &response) override { response << "MyCompleteController: POST"; return true; } - - bool IsAllowed(const VWeb::Request& request) { - return request.HasHeader("Auth"); - } }; -bool Ping(const VWeb::Request&, VWeb::Response& response) { +bool Ping(const VWeb::Request &, VWeb::Response &response) { response << "Pong"; return true; } int main() { using namespace VWeb; VWeb::Server server; - auto& router = server.GetRouter(); + auto &router = server.GetRouter(); // For debugging and profiling more than 1 thread can be hard. server.GetServerConfig()->WorkerThreads = 1; - router->Get("/test", [](Request&, Response& response) { + router->Get("/test", [](Request &, Response &response) { response << "NICE"; return true; }); router->Get("/ping", &Ping); + router->Register("/auth/:id/", + HttpMethod::GET | HttpMethod::POST); server.Start(); server.Join(); return 0;