diff --git a/CMakeLists.txt b/CMakeLists.txt index 354df8e..e7382cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ set(SOURCE_FILES Source/StringUtils.cpp Source/Cookie.cpp Source/Session.cpp - Source/Response.cpp) + Source/Response.cpp Source/InbuildMiddleWare.cpp) include_directories(${CMAKE_SOURCE_DIR}/) add_library(VWeb ${SOURCE_FILES}) diff --git a/Source/InbuildMiddleWare.cpp b/Source/InbuildMiddleWare.cpp new file mode 100644 index 0000000..f4d70ca --- /dev/null +++ b/Source/InbuildMiddleWare.cpp @@ -0,0 +1,121 @@ +#include "StringUtils.h" +#include "VWeb.h" + +#include +#include + +namespace VWeb { +#pragma region AUTH +PreMiddleWareReturn AuthWare::PreHandle(Request &request) { + if (m_AuthFunction) + return m_AuthFunction(request); + return {}; +} + +#pragma endregion AUTH + +#pragma region SESSION + +static std::string GenerateSID() { + static std::random_device dev; + static std::mt19937 rng(dev()); + + std::uniform_int_distribution dist(0, 15); + + const char *v = "0123456789abcdef"; + const bool dashArray[] = {false, false, false, false, true, false, + true, false, true, false, true, false, + false, false, false, false}; + + std::stringstream res; + for (bool dash : dashArray) { + if (dash) + res << "-"; + res << v[dist(rng)]; + res << v[dist(rng)]; + } + return res.str(); +} + +SessionManager::SessionManager() { + m_GCThread = CreateRef([this]() { + while (m_IsRunning) { + if (m_Counter == 59) { + GC(); + m_Counter = -1; + } + m_Counter++; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); +} +SessionManager::~SessionManager() { + m_IsRunning = false; + m_GCThread->join(); +} + +PreMiddleWareReturn SessionManager::PreHandle(Request &request) { + auto &cookies = request.CookieData; + if (!cookies->Has("sid")) + return {}; + auto &cookie = cookies->Get("sid"); + std::lock_guard lck(m_Mutex); + if (!m_Sessions.contains(cookie.Value)) + return {}; + auto &session = m_Sessions[cookie.Value]; + session->Update(); + request.SessionData = session; + return {}; +} + +bool SessionManager::PostHandle(const Request &request, Response &response) { + if (response.SessionData->Id.empty() && + response.SessionData->ContainsData()) { + response.SessionData->Update(); + { + std::lock_guard lck(m_Mutex); + response.SessionData->Id = GenerateSID(); + while (m_Sessions.contains(response.SessionData->Id)) + response.SessionData->Id = GenerateSID(); + } + m_Sessions[response.SessionData->Id] = response.SessionData; + auto &sidCookie = response.CookieData->Get("sid"); + sidCookie.HttpOnly = true; + sidCookie.Secure = true; + sidCookie.Value = response.SessionData->Id; + } + return true; +} + +void SessionManager::GC() { + std::lock_guard lck(m_Mutex); + std::vector mark_as_delete; + for (auto &itemToCollect : m_Sessions) { + if (!itemToCollect.second->IsValid()) + mark_as_delete.push_back(itemToCollect.first); + } + + for (auto &i : mark_as_delete) + m_Sessions.erase(i.data()); +} +#pragma endregion SESSION +#pragma region COOKIES +PreMiddleWareReturn CookieManager::PreHandle(Request &request) { + auto &cookieHeaders = request.Header("Cookie"); + auto &cookies = request.CookieData; + if (cookieHeaders.Size() > 0) { + auto &values = cookieHeaders.Values(); + for (auto &rawCookie : cookieHeaders.Values()) { + auto splitCookies = String::Split(rawCookie, ";"); + for (auto &cookie : splitCookies) { + auto split = String::Split(cookie, "="); + String::Trim(split[0]); + String::Trim(split[1]); + cookies->CreateOld(split[0], split[1]); + } + } + } + return {}; +} +#pragma endregion COOKIES +} // namespace VWeb \ No newline at end of file diff --git a/Source/MiddleWare.cpp b/Source/MiddleWare.cpp index c25bae2..3b0e8b9 100644 --- a/Source/MiddleWare.cpp +++ b/Source/MiddleWare.cpp @@ -1,15 +1,21 @@ #include namespace VWeb { -void MiddleWareHandler::HandlePre(Ref &request) { +std::optional> +MiddleWareHandler::HandlePre(Ref &request) { for (auto &[key, middleWare] : m_MiddleWares) { - middleWare->PreHandle(*request); + auto data = middleWare->PreHandle(*request); + if (data.has_value()) + return data; } + return {}; } void MiddleWareHandler::HandlePost(Ref &request, Ref &response) { for (auto &[key, middleWare] : m_MiddleWares) { - middleWare->PostHandle(*request, *response); + if (!middleWare->PostHandle(*request, *response)) { + break; + } } } void MiddleWareHandler::Shutdown(Ref &request, diff --git a/Source/RequestHandler.cpp b/Source/RequestHandler.cpp index d555e6e..e309b46 100644 --- a/Source/RequestHandler.cpp +++ b/Source/RequestHandler.cpp @@ -55,8 +55,15 @@ struct RequestJob : public WorkerJob { } MRequest->CookieData = CreateRef(); MRequest->SessionData = CreateRef(); - MMiddleWareHandler->HandlePre(MRequest); - auto response = MRouter->HandleRoute(MRequest); + Ref response; + auto preValue = MMiddleWareHandler->HandlePre(MRequest); + if (preValue.has_value()) { + response = preValue.value(); + response->SessionData = MRequest->SessionData; + response->CookieData = MRequest->CookieData; + } else { + response = MRouter->HandleRoute(MRequest); + } MMiddleWareHandler->HandlePost(MRequest, response); auto content = response->GetResponse(); MRequestHandler->AddSendResponse( diff --git a/Source/Response.cpp b/Source/Response.cpp index 7250b9b..bbdb527 100644 --- a/Source/Response.cpp +++ b/Source/Response.cpp @@ -65,6 +65,13 @@ static std::unordered_map s_HTTPCodeToString = { {HttpStatusCode::NetworkAuthenticationRequired,"511 Network Authentication Required"} }; // clang-format on + +Ref Response::FromCode(HttpStatusCode code) { + auto response = CreateRef(); + response->SetStatus(code); + return response; +} + std::string Response::GetResponse() { std::string content = m_Content.str(); auto headData = TransformHeaders(content); diff --git a/Source/Route.cpp b/Source/Route.cpp index 11f5458..dbd3b89 100644 --- a/Source/Route.cpp +++ b/Source/Route.cpp @@ -80,4 +80,8 @@ bool Route::SupportsMethod(const Request &request) { std::find(m_AllowedMethods.begin(), m_AllowedMethods.end(), request.Method) != m_AllowedMethods.end(); } +void Route::AllowMethod(HttpMethod method) { + if (std::find(m_AllowedMethods.begin(), m_AllowedMethods.end(), method) == m_AllowedMethods.end()) + m_AllowedMethods.push_back(method); +} } // namespace VWeb \ No newline at end of file diff --git a/Source/Router.cpp b/Source/Router.cpp index 4d38247..de1130d 100644 --- a/Source/Router.cpp +++ b/Source/Router.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace VWeb { @@ -20,6 +21,33 @@ public: } }; +class InstanceHandleRoute : public Route { +public: + bool Get(const Request &request, Response &response) override { + return GetFunc && GetFunc(request, response); + } + bool Post(const Request &request, Response &response) override { + return PostFunc && PostFunc(request, response); + } + bool Put(const Request &request, Response &response) override { + return PutFunc && PutFunc(request, response); + } + bool Patch(const Request &request, Response &response) override { + return PatchFunc && PatchFunc(request, response); + } + bool Delete(const Request &request, Response &response) override { + return DeleteFunc && DeleteFunc(request, response); + } + +protected: + RouteFunction GetFunc{nullptr}; + RouteFunction PostFunc{nullptr}; + RouteFunction PutFunc{nullptr}; + RouteFunction PatchFunc{nullptr}; + RouteFunction DeleteFunc{nullptr}; + friend Router; +}; + Router::Router() { m_Routes["@"] = CreateRef(); } void Router::AddRoute(const std::string &name, const Ref &route) { @@ -38,6 +66,7 @@ Ref Router::HandleRoute(Ref &request) { auto response = CreateRef(); auto route = FindRoute(request); response->CookieData = request->CookieData; + response->SessionData = request->SessionData; response->Method = request->Method; if (!route) { response->SetStatus(HttpStatusCode::NotFound); @@ -94,4 +123,38 @@ void Router::AddToArgs(Ref &request, Vector &items) { request->URLParameters.push_back(items[items.size() - 1]); items.pop_back(); } + +InstanceHandleRoute* GetOrNull(const std::string& path, Router* router) { + InstanceHandleRoute* route; + if (!router->m_Routes.contains(path)) { + router->AddRoute(path, CreateRef()); + } + route = dynamic_cast(router->m_Routes[path].get()); + return route; +} +void Router::Get(const std::string &path, RouteFunction func) { + auto* route = GetOrNull(path, this); + if (route) + route->GetFunc = std::move(func); +} +void Router::Post(const std::string &path, RouteFunction func) { + auto* route = GetOrNull(path, this); + if (route) + route->PostFunc = std::move(func); +} +void Router::Put(const std::string &path, RouteFunction func) { + auto* route = GetOrNull(path, this); + if (route) + route->PutFunc = std::move(func); +} +void Router::Patch(const std::string &path, RouteFunction func) { + auto* route = GetOrNull(path, this); + if (route) + route->PatchFunc = std::move(func); +} +void Router::Delete(const std::string &path, RouteFunction func) { + auto* route = GetOrNull(path, this); + if (route) + route->DeleteFunc = std::move(func); +} } // namespace VWeb \ No newline at end of file diff --git a/Source/Server.cpp b/Source/Server.cpp index 5bb8719..59327e2 100644 --- a/Source/Server.cpp +++ b/Source/Server.cpp @@ -8,6 +8,10 @@ Server::Server() { m_ServerConfig->EPoll = CreateRef(); m_ServerConfig->Socket = CreateRef(m_ServerConfig); m_RequestHandler = CreateRef(m_ServerConfig->Socket); + auto& middleWare = m_RequestHandler->Middleware(); + middleWare->Create(); + middleWare->Create(); + middleWare->Create(); m_RequestHandler->m_Server = this; }; void Server::LoadSharedLibs() { @@ -132,4 +136,8 @@ void Server::CreateRequest(int sockID) { m_RawRequest.Remove(sockID); m_RequestHandler->AddRequest(request); } + +Ref &Server::Middleware() { + return m_RequestHandler->Middleware(); +} } // namespace VWeb \ No newline at end of file diff --git a/VWeb.h b/VWeb.h index 2ec23ef..815c6c4 100644 --- a/VWeb.h +++ b/VWeb.h @@ -14,7 +14,9 @@ #include #include #include +#include #include +#include namespace VWeb { @@ -253,6 +255,7 @@ protected: bool m_IsErrored{false}; }; +class MiddleWareHandler; class RequestHandler; class Response; class Router; @@ -272,6 +275,7 @@ public: void AddRoute(const std::string &path, const Ref &route); void RemoveRoute(const std::string &path); + Ref& Middleware(); protected: void Execute(); void OutgoingExecute(epoll_event &event); @@ -305,6 +309,8 @@ struct Session { void Remove(const std::string &key); bool Has(const std::string &key); Ref &operator[](const std::string &key) { return m_Data[key]; } + void SetSessionData(const std::string& key, const Ref& data) { m_Data[key] = data; } + bool ContainsData() { return !m_Data.empty(); } protected: std::chrono::time_point m_LastCall = @@ -439,6 +445,9 @@ public: bool HasParameter(const std::string &key) const { return Parameters.contains(key); } + bool HasHeader(const std::string &key) const { + return Headers.contains(key); + } std::string &FirstOf(const std::string &key) { return Parameters[key].GetFirst(); } @@ -448,6 +457,8 @@ public: Vector URLParameters; }; class Response { +public: + static Ref FromCode(HttpStatusCode code); public: size_t Length{0}; Ref CookieData{nullptr}; @@ -498,6 +509,7 @@ protected: friend Server; }; +typedef std::function RouteFunction; class Route { public: Route() = default; @@ -514,6 +526,8 @@ public: bool SupportsMethod(const Request &request); virtual bool IsAllowed(const Request &request); + void AllowMethod(HttpMethod method); + protected: bool m_AllowAll{true}; Vector m_AllowedMethods; @@ -531,43 +545,86 @@ public: Ref FindRoute(Ref &request); static void AddToArgs(Ref &request, Vector &items); -protected: +public: + void Get(const std::string& path, RouteFunction); + void Post(const std::string& path, RouteFunction); + void Put(const std::string& path, RouteFunction); + void Patch(const std::string& path, RouteFunction); + void Delete(const std::string& path, RouteFunction); std::unordered_map> m_Routes; }; +typedef std::optional> PreMiddleWareReturn; struct MiddleWare { int Pos{0}; - virtual void PreHandle(Request &){}; - virtual void PostHandle(Request &, Response &){}; - virtual void Shutdown(Request &, Response &){}; + virtual PreMiddleWareReturn PreHandle(Request &){ return {}; } + virtual bool PostHandle(const Request &, Response &){ return true; } + virtual void Shutdown(const Request &, const Response &){}; bool operator<(const MiddleWare *rhs) const { return Pos < rhs->Pos; } }; class MiddleWareHandler { public: - void HandlePre(Ref &); + PreMiddleWareReturn HandlePre(Ref &); void HandlePost(Ref &, Ref &); void Shutdown(Ref &, Ref &); public: - template Ref Get() { return GetById(typeid(T).name()); } + template Ref GetRef() { return GetById(typeid(T).name()); } + template T& Get() { return static_cast(*GetById(typeid(T).name())); } template void Set(Ref &instance) { auto &type = typeid(T); if (type.before(typeid(MiddleWare))) SetById(type.name(), instance); } - template Ref Create() { - return SetById(typeid(T).name(), CreateRef()); + template T& Create() { + return static_cast(*CreateMiddleWare()); } template void Remove() { RemoveById(typeid(T).name()); } - protected: + template Ref CreateMiddleWare() { + return SetById(typeid(T).name(), CreateRef()); + } Ref GetById(const char *id); Ref SetById(const char *id, const Ref &); void RemoveById(const char *id); std::map> m_MiddleWares; }; +typedef std::function AuthFunction; +class AuthWare : public MiddleWare { +public: + AuthWare() = default; + ~AuthWare() = default; + PreMiddleWareReturn PreHandle(Request &request) override; + void SetAuthMethod(AuthFunction function) { m_AuthFunction = std::move(function); } +protected: + AuthFunction m_AuthFunction{nullptr}; +}; + +class SessionManager : public MiddleWare { +public: + SessionManager(); + ~SessionManager(); + PreMiddleWareReturn PreHandle(Request& request) override; + bool PostHandle(const Request& request, Response& response) override; +protected: + void GC(); +protected: + Ref m_GCThread; + std::mutex m_Mutex; + std::unordered_map> m_Sessions; + int m_Counter{-1}; + bool m_IsRunning{true}; +}; + +class CookieManager : public MiddleWare { +public: + CookieManager() = default; + ~CookieManager() = default; + PreMiddleWareReturn PreHandle(Request&) override; +}; + #pragma endregion VWEB_ROUTING #pragma endregion VWEB } // namespace VWeb \ No newline at end of file diff --git a/dist/libVWeb.debug.a b/dist/libVWeb.debug.a index 999123b..635722c 100644 Binary files a/dist/libVWeb.debug.a and b/dist/libVWeb.debug.a differ diff --git a/dist/libVWeb.release.a b/dist/libVWeb.release.a index 79fbe2c..3c333d9 100644 Binary files a/dist/libVWeb.release.a and b/dist/libVWeb.release.a differ diff --git a/example/main.cpp b/example/main.cpp index 76c259c..5584582 100644 --- a/example/main.cpp +++ b/example/main.cpp @@ -1,19 +1,18 @@ -#include #include -class BigDataRoute : public VWeb::Route { -public: - bool Execute(const VWeb::Request &request, VWeb::Response &response) override { - response << "
\n";
-    for (int i = 0; i < 100; ++i)
-      response << "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet.\n";
-    response << "
"; - return true; - } -}; +bool Ping(const VWeb::Request&, VWeb::Response& response) { + response << "Pong"; + return true; +} int main() { + using namespace VWeb; VWeb::Server server; - server.AddRoute("/test", VWeb::CreateRef()); + auto& router = server.GetRouter(); + router->Get("/test", [](const Request&, Response& response) { + response << "NICE"; + return true; + }); + router->Get("/ping", &Ping); server.Start(); server.Join(); return 0;