// VWeb/Source/Router.cpp
//
// 184 lines
// 5.8 KiB
// C++

#include "Includes/VWeb.h"
#include "StringUtils.h"
#include <algorithm>
#include <iostream>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
namespace VWeb {
// Upper-case wire name for every HttpMethod enumerator. Used as the key prefix
// for method-specific routes ("GET" + path) and when building the "Allow"
// header for OPTIONS responses.
static std::unordered_map<HttpMethod, std::string> s_HttpMethodToString = {
    {HttpMethod::HEAD, "HEAD"},         {HttpMethod::GET, "GET"},
    {HttpMethod::OPTIONS, "OPTIONS"},   {HttpMethod::POST, "POST"},
    {HttpMethod::PUT, "PUT"},           {HttpMethod::PATCH, "PATCH"},
    {HttpMethod::DELETE, "DELETE"},     {HttpMethod::FALLBACK, "FALLBACK"}};
template <typename E> constexpr auto to_underlying(E e) noexcept {
return static_cast<std::underlying_type_t<E>>(e);
}
class ErrorRoute : public Route {
public:
bool Execute(Request &request, Response &response) override {
response.Reset();
response << "Unhandled Error: Status "
<< std::to_string(to_underlying(response.Status));
response.SetType("text/plain");
return true;
}
};
/// Registers a route factory under `path`.
///
/// `path` is split on '/'. Empty segments are skipped, so leading, trailing
/// and doubled slashes are ignored. Any missing node along the way is created
/// with a fresh ID taken from m_NodeID.
///
/// The factory and its allowed-method mask are stored in m_Routes, keyed by
/// the final node's ID. Registering the same path again replaces the earlier
/// entry.
void RouteTree::Add(const std::string &path, uint32_t allowedMethods,
                    RouteInstaniateFunction instaniate) {
  auto segments = String::Split(path, "/");
  auto node = &Root;
  for (const auto &segment : segments) {
    if (segment.empty())
      continue;
    // A single lookup on the hot (existing-node) path. The child is built
    // before it is inserted, so a throwing allocation leaves no empty slot
    // behind.
    auto child = node->Children.find(segment);
    if (child == node->Children.end())
      child = node->Children
                  .emplace(segment, std::make_unique<Node>(m_NodeID++))
                  .first;
    node = child->second.get();
  }
  m_Routes[node->ID] = {.AllowedMethods = allowedMethods,
                        .Instaniate = std::move(instaniate)};
}
/// Resolves `path` against the route tree, one '/'-separated segment at a time.
///
/// Matching rules:
/// - Empty segments are skipped.
/// - A literal child is preferred.
/// - Otherwise the first child whose key starts with ':' (in the container's
///   iteration order) is taken. The segment is stored in
///   request.URLParameters under that key without its ':' prefix.
/// - Matching is greedy; there is no backtracking after a literal match.
///
/// @return A newly instantiated route with its allowed-method mask applied.
///         Returns nullptr when nothing matches, or when the factory yields no
///         route; request.URLParameters is cleared in those cases.
Ref<Route> RouteTree::Find(const std::string &path, Request &request) {
  const auto segments = String::Split(path, "/");
  auto node = &Root;
  for (const auto &segment : segments) {
    if (segment.empty())
      continue;
    if (auto it = node->Children.find(segment); it != node->Children.end()) {
      node = it->second.get();
      continue;
    }
    // No literal child: fall back to a ":name" parameter child.
    bool foundParameter = false;
    for (auto &[key, child] : node->Children) {
      if (key.starts_with(':')) {
        node = child.get();
        request.URLParameters[key.substr(1)] = segment;
        foundParameter = true;
        break;
      }
    }
    if (!foundParameter) {
      request.URLParameters = {};
      return nullptr;
    }
  }

  // One lookup, and no mutating operator[] on m_Routes.
  const auto routeIt = m_Routes.find(node->ID);
  if (routeIt == m_Routes.end()) {
    request.URLParameters = {};
    return nullptr;
  }
  const auto &instance = routeIt->second;
  auto ref = instance.Instaniate();
  if (!ref) {
    // A factory that yields nothing is treated as "no route" instead of
    // being dereferenced below.
    request.URLParameters = {};
    return nullptr;
  }
  ref->SetAllowedMethods(instance.AllowedMethods);
  return ref;
}
/// Registers ErrorRoute under the reserved "@" key. HandleRoute looks this
/// key up as its generic error fallback.
Router::Router() {
  Register<ErrorRoute>("@");
}
/// Sets the "Allow" header on `response`. The value is a ", "-separated list
/// of every method in s_HttpMethodToString whose bit is set in
/// `allowedMethods`.
///
/// Names are emitted in ascending order of their HttpMethod value, so the
/// header is the same on every run. Iterating the unordered_map directly would
/// follow its unspecified bucket order.
static void HandleOptions(Ref<Response> &response, uint32_t allowedMethods) {
  std::vector<std::pair<uint32_t, const std::string *>> allowed;
  allowed.reserve(s_HttpMethodToString.size());
  for (const auto &[method, name] : s_HttpMethodToString) {
    const auto bit = static_cast<uint32_t>(method);
    if (allowedMethods & bit)
      allowed.emplace_back(bit, &name);
  }
  std::sort(allowed.begin(), allowed.end(),
            [](const auto &lhs, const auto &rhs) { return lhs.first < rhs.first; });

  std::string header;
  for (std::size_t i = 0; i < allowed.size(); ++i) {
    if (i != 0)
      header += ", ";
    header += *allowed[i].second;
  }
  response->SetHeader("Allow", std::move(header));
}
/// Dispatches `request` to a route and returns the populated response.
///
/// Lookup order:
/// 1. The raw URI, via FindRoute.
/// 2. The URI prefixed with the method name. This is how Get/Post/... register
///    inline routes.
///
/// Outcomes:
/// - No match: status NotFound, then the "@" error route.
/// - IsAllowed() is false: status Forbidden, then the "@" error route.
/// - OPTIONS request: only the "Allow" header is set.
/// - Handler returns false: the "@<status>" route runs if it is registered,
///   otherwise "@".
Ref<Response> Router::HandleRoute(Ref<Request> &request) {
  auto response = CreateRef<Response>();
  auto route = FindRoute(request);
  response->CookieData = request->CookieData;
  response->SessionData = request->SessionData;
  response->Method = request->Method;

  // Runs the generic "@" error route. The null check means a missing "@"
  // route leaves the response as-is instead of dereferencing nullptr.
  const auto runDefaultErrorRoute = [&] {
    if (auto errorRoute = m_Tree.Find("@", *request))
      errorRoute->Execute(*request, *response);
  };

  if (!route) {
    // find() instead of operator[]: operator[] inserts an entry into the
    // shared static table for an unknown method. Mutating it while serving
    // requests is a data race if requests are dispatched concurrently.
    // An unknown method still gets an empty prefix, as before.
    std::string prefixedPath;
    if (auto it = s_HttpMethodToString.find(request->Method);
        it != s_HttpMethodToString.end())
      prefixedPath = it->second;
    prefixedPath += request->URI;
    route = m_Tree.Find(prefixedPath, *request);
    if (!route) {
      response->SetStatus(HttpStatusCode::NotFound);
      runDefaultErrorRoute();
      return response;
    }
  }
  if (!route->IsAllowed(*request)) {
    response->SetStatus(HttpStatusCode::Forbidden);
    runDefaultErrorRoute();
    return response;
  }
  if (request->Method == HttpMethod::OPTIONS) {
    HandleOptions(response, route->GetAllowedMethods());
    return response;
  }
  if (!route->Execute(*request, *response)) {
    // Prefer a status-specific handler such as "@404"; otherwise use "@".
    const std::string statusKey =
        "@" + std::to_string(to_underlying(response->Status));
    if (auto statusRoute = m_Tree.Find(statusKey, *request))
      statusRoute->Execute(*request, *response);
    else
      runDefaultErrorRoute();
  }
  return response;
}
/// Looks up the raw request URI in the route tree.
///
/// Paths whose first non-empty segment starts with '@' are refused, because
/// that namespace holds the internal error routes ("@", "@<status>").
/// RouteTree::Find skips empty segments, so "/@404" resolves exactly like
/// "@404". Checking only the URI's first character is therefore not enough:
/// that character is normally '/', and clients could invoke error routes
/// directly.
Ref<Route> Router::FindRoute(Ref<Request> &request) {
  const auto &url = request->URI;
  const auto firstSegment = url.find_first_not_of('/');
  if (firstSegment != url.npos && url[firstSegment] == '@')
    return nullptr;
  return m_Tree.Find(url, *request);
}
struct InlineRoute : Route {
explicit InlineRoute(RouteFunction function) : Func(std::move(function)) {}
RouteFunction Func;
bool Execute(Request &request, Response &response) override {
Func(request, response);
return true;
}
bool IsAllowed(Request &request) override { return true; }
};
void Router::Get(const std::string &path, const RouteFunction &func) {
m_Tree.Add(s_HttpMethodToString[HttpMethod::GET] + path,
(uint32_t)HttpMethod::GET,
[func] { return std::make_shared<InlineRoute>(func); });
}
void Router::Post(const std::string &path, const RouteFunction &func) {
m_Tree.Add(s_HttpMethodToString[HttpMethod::POST] + path,
(uint32_t)HttpMethod::POST,
[func] { return std::make_shared<InlineRoute>(func); });
}
void Router::Put(const std::string &path, const RouteFunction &func) {
m_Tree.Add(s_HttpMethodToString[HttpMethod::PUT] + path,
(uint32_t)HttpMethod::PUT,
[func] { return std::make_shared<InlineRoute>(func); });
}
void Router::Patch(const std::string &path, const RouteFunction &func) {
m_Tree.Add(s_HttpMethodToString[HttpMethod::PATCH] + path,
(uint32_t)HttpMethod::PATCH,
[func] { return std::make_shared<InlineRoute>(func); });
}
void Router::Delete(const std::string &path, const RouteFunction &func) {
m_Tree.Add(s_HttpMethodToString[HttpMethod::DELETE] + path,
(uint32_t)HttpMethod::DELETE,
[func] { return std::make_shared<InlineRoute>(func); });
}
} // namespace VWeb