#include "Includes/VWeb.h"
#include "StringUtils.h"

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>

namespace VWeb {

// Builds one {HttpMethod, name} table entry. The name is the stringified
// enumerator with its "HttpMethod::" qualifier removed, e.g.
// stringify(HttpMethod::GET) -> { HttpMethod::GET, "GET" }.
#define stringify(name)                                                        \
  { name, std::string(#name).substr(sizeof("HttpMethod::") - 1) }

// Method -> textual name. The name is used as the leading path segment of
// routes added through Router::Get/Post/... and as the value emitted in the
// "Allow" header. The table is const so request handling can never add
// entries to it.
static const std::unordered_map<HttpMethod, std::string> s_HttpMethodToString =
    {stringify(HttpMethod::HEAD),    stringify(HttpMethod::GET),
     stringify(HttpMethod::OPTIONS), stringify(HttpMethod::POST),
     stringify(HttpMethod::PUT),     stringify(HttpMethod::PATCH),
     stringify(HttpMethod::DELETE),  stringify(HttpMethod::FALLBACK)};
#undef stringify

// Returns the textual name of `method`, or an empty string when the method is
// not in the table. Unlike operator[], this never inserts into the table.
static const std::string &MethodName(HttpMethod method) {
  static const std::string s_Empty;
  const auto it = s_HttpMethodToString.find(method);
  return it != s_HttpMethodToString.end() ? it->second : s_Empty;
}

// Converts an enumerator to its underlying integer type.
template <typename E>
constexpr auto to_underlying(E e) noexcept {
  return static_cast<std::underlying_type_t<E>>(e);
}

// Route that replaces the response body with a plain-text
// "Unhandled Error: Status <code>" message.
class ErrorRoute : public Route {
public:
  bool Execute(Request &request, Response &response) override {
    response.Reset();
    // Status is read after Reset(), exactly as before.
    // NOTE(review): this relies on Reset() leaving Status intact - confirm.
    response << "Unhandled Error: Status "
             << std::to_string(to_underlying(response.Status));
    response.SetType("text/plain");
    return true;
  }
};

// Registers a route factory for `path`.
//
// The path is split on "/" and empty segments are ignored. Missing tree nodes
// are created on the way down, each taking the next value of m_NodeID. The
// route record is stored under the ID of the final node; a record already
// stored for that node is overwritten.
void RouteTree::Add(const std::string &path, uint32_t allowedMethods,
                    RouteInstaniateFunction instaniate) {
  auto segments = String::Split(path, "/");
  auto *node = &Root;
  // New nodes have the same type as Root; derive it instead of spelling it.
  using Node = std::remove_pointer_t<decltype(node)>;

  for (const auto &segment : segments) {
    if (segment.empty())
      continue;

    auto it = node->Children.find(segment);
    if (it == node->Children.end()) {
      auto &slot = node->Children[segment];
      slot = std::make_unique<Node>(m_NodeID++);
      node = slot.get();
    } else {
      node = it->second.get();
    }
  }

  m_Routes[node->ID] = {.AllowedMethods = allowedMethods,
                        .Instaniate = std::move(instaniate)};
}

// Resolves `path` to a freshly instantiated route.
//
// Each non-empty segment is matched against the children of the current node
// by exact key first. If there is no exact match, the first child (in the
// container's iteration order) whose key starts with ':' is taken instead and
// the segment is stored in request.URLParameters under the key without the
// leading ':'.
//
// Returns nullptr, and resets request.URLParameters to empty, when a segment
// cannot be matched or the final node has no route registered.
Ref<Route> RouteTree::Find(const std::string &path, Request &request) {
  auto segments = String::Split(path, "/");
  auto *node = &Root;

  for (const auto &segment : segments) {
    if (segment.empty())
      continue;

    if (auto it = node->Children.find(segment); it != node->Children.end()) {
      node = it->second.get();
      continue;
    }

    // No literal child: fall back to a ":name" parameter child.
    bool foundParameter = false;
    for (auto &[key, child] : node->Children) {
      if (!key.empty() && key.front() == ':') {
        request.URLParameters[key.substr(1)] = segment;
        node = child.get();
        foundParameter = true;
        break;
      }
    }

    if (!foundParameter) {
      request.URLParameters = {};
      return nullptr;
    }
  }

  // Single lookup instead of contains() followed by operator[].
  const auto routeIt = m_Routes.find(node->ID);
  if (routeIt == m_Routes.end()) {
    request.URLParameters = {};
    return nullptr;
  }

  const auto &instance = routeIt->second;
  auto ref = instance.Instaniate();
  if (ref)
    ref->SetAllowedMethods(instance.AllowedMethods);
  return ref;
}

// Installs the generic error route under "@". HandleRoute falls back to this
// route for not-found, forbidden and otherwise unhandled statuses.
// NOTE(review): the template argument is reconstructed as ErrorRoute, the only
// route class defined in this file before this point - confirm against the
// declaration of Router::Register.
Router::Router() { Register<ErrorRoute>("@"); }

// Sets the "Allow" header to a ", "-separated list of the method names whose
// bit is set in `allowedMethods`. Names appear in the iteration order of the
// method table.
static void HandleOptions(Ref<Response> &response, uint32_t allowedMethods) {
  std::string allow;
  for (const auto &[key, value] : s_HttpMethodToString) {
    if ((allowedMethods & static_cast<uint32_t>(key)) == 0)
      continue;
    if (!allow.empty())
      allow += ", ";
    allow += value;
  }
  response->SetHeader("Allow", allow);
}

// Dispatches `request` and returns the response.
//
// Lookup order: the raw URI first (via FindRoute), then the URI prefixed with
// the request's method name, which is how Get/Post/Put/Patch/Delete store
// their routes.
//   - no route found          -> status NotFound, "@" route executed
//   - route->IsAllowed false  -> status Forbidden, "@" route executed
//   - OPTIONS request         -> only the "Allow" header is set
//   - route->Execute false    -> "@<status>" route executed if it exists,
//                                otherwise the "@" route
Ref<Response> Router::HandleRoute(Ref<Request> &request) {
  auto response = CreateRef<Response>();
  auto route = FindRoute(request);

  response->CookieData = request->CookieData;
  response->SessionData = request->SessionData;
  response->Method = request->Method;

  // Runs the generic "@" error route. Guarded so that a missing "@" route
  // leaves the response as it is instead of dereferencing a null pointer.
  const auto runFallback = [&] {
    if (auto fallback = m_Tree.Find("@", *request))
      fallback->Execute(*request, *response);
  };

  if (!route) {
    // MethodName() does not insert into the table for unknown methods; it
    // yields an empty prefix, which is what operator[] produced before.
    route = m_Tree.Find(MethodName(request->Method) + request->URI, *request);
    if (!route) {
      response->SetStatus(HttpStatusCode::NotFound);
      runFallback();
      return response;
    }
  }

  if (!route->IsAllowed(*request)) {
    response->SetStatus(HttpStatusCode::Forbidden);
    runFallback();
    return response;
  }

  if (request->Method == HttpMethod::OPTIONS) {
    HandleOptions(response, route->GetAllowedMethods());
    return response;
  }

  if (!route->Execute(*request, *response)) {
    // Prefer a status-specific error route such as "@404".
    const std::string statusKey =
        "@" + std::to_string(to_underlying(response->Status));
    if (auto statusRoute = m_Tree.Find(statusKey, *request)) {
      statusRoute->Execute(*request, *response);
    } else {
      runFallback();
    }
  }

  return response;
}

// Looks the request URI up in the route tree as-is. URIs that begin with '@'
// are rejected.
// NOTE(review): RouteTree::Find skips empty segments, so a URI such as "/@"
// or "/GET/x" does not begin with '@' and is still resolved against the "@"
// and method-prefixed nodes - confirm whether that is intended.
Ref<Route> Router::FindRoute(Ref<Request> &request) {
  const auto &url = request->URI;
  if (url.starts_with("@"))
    return nullptr;
  return m_Tree.Find(url, *request);
}

// Route that wraps a plain callback. The callback's work always counts as
// handled (Execute returns true) and every request is allowed.
struct InlineRoute : Route {
  explicit InlineRoute(RouteFunction function) : Func(std::move(function)) {}

  RouteFunction Func;

  bool Execute(Request &request, Response &response) override {
    Func(request, response);
    return true;
  }

  bool IsAllowed(Request &request) override { return true; }
};

// Returns a factory that creates a new InlineRoute around a copy of `func` on
// every call.
static RouteInstaniateFunction MakeInlineFactory(const RouteFunction &func) {
  return [func] { return std::make_shared<InlineRoute>(func); };
}

// The helpers below register `func` under "<METHOD><path>" and allow only
// that method.

void Router::Get(const std::string &path, const RouteFunction &func) {
  m_Tree.Add(MethodName(HttpMethod::GET) + path,
             static_cast<uint32_t>(HttpMethod::GET), MakeInlineFactory(func));
}

void Router::Post(const std::string &path, const RouteFunction &func) {
  m_Tree.Add(MethodName(HttpMethod::POST) + path,
             static_cast<uint32_t>(HttpMethod::POST), MakeInlineFactory(func));
}

void Router::Put(const std::string &path, const RouteFunction &func) {
  m_Tree.Add(MethodName(HttpMethod::PUT) + path,
             static_cast<uint32_t>(HttpMethod::PUT), MakeInlineFactory(func));
}

void Router::Patch(const std::string &path, const RouteFunction &func) {
  m_Tree.Add(MethodName(HttpMethod::PATCH) + path,
             static_cast<uint32_t>(HttpMethod::PATCH), MakeInlineFactory(func));
}

void Router::Delete(const std::string &path, const RouteFunction &func) {
  m_Tree.Add(MethodName(HttpMethod::DELETE) + path,
             static_cast<uint32_t>(HttpMethod::DELETE),
             MakeInlineFactory(func));
}

} // namespace VWeb