From 8375f92802d3aa7667bfc6f22f6d2a72361c9808 Mon Sep 17 00:00:00 2001 From: Star Rauchenberger Date: Mon, 4 Nov 2024 11:46:24 -0500 Subject: Websockets server! --- server_main.cpp | 230 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 server_main.cpp (limited to 'server_main.cpp') diff --git a/server_main.cpp b/server_main.cpp new file mode 100644 index 0000000..8da6477 --- /dev/null +++ b/server_main.cpp @@ -0,0 +1,230 @@ +#define ASIO_STANDALONE + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cardset.h" +#include "database.h" +#include "imagestore.h" +#include "wizard.h" + +namespace { + +using socket_type = websocketpp::server; + +class server { + public: + server(const cardset& cards, const imagestore& images, std::mt19937& rng) + : cards_(cards), images_(images), rng_(rng) { + socket_.init_asio(); + + socket_.set_message_handler([this](websocketpp::connection_hdl connection, + socket_type::message_ptr message) { + on_message(connection, message); + }); + } + + void run() { + socket_.listen(9002); + socket_.start_accept(); + + asio::post(std::bind(&server::cleanup_thread, this)); + + socket_.run(); + } + + private: + void on_message(websocketpp::connection_hdl connection, + socket_type::message_ptr message) { + nlohmann::json msgJson; + + try { + msgJson = nlohmann::json::parse(message->get_payload()); + } catch (const std::exception& ex) { + std::string response = R"( + { + "type": "error", + "msg": "Could not parse message." + } + )"; + + socket_.send(connection, response, + websocketpp::frame::opcode::value::TEXT); + return; + } + + std::string cmd = msgJson["cmd"]; + if (cmd == "generate") { + cmd_generate(connection, msgJson["text"]); + } else if (cmd == "check") { + cmd_check(connection, msgJson["token"]); + } else { + std::string response = R"( + { + "type": "error", + "msg": "No command in message." + } + )"; + + socket_.send(connection, response, + websocketpp::frame::opcode::value::TEXT); + return; + } + } + + void cmd_generate(websocketpp::connection_hdl connection, std::string text) { + std::string token; + + { + std::lock_guard rng_guard(rng_mutex_); + token = database_.create(rng_); + } + + nlohmann::json tokenMsg; + tokenMsg["type"] = "token"; + tokenMsg["token"] = token; + socket_.send(connection, tokenMsg.dump(), + websocketpp::frame::opcode::value::TEXT); + + database_.subscribe(token, [this, connection](const std::string& msg) { + socket_.send(connection, msg, websocketpp::frame::opcode::value::TEXT); + }); + + asio::post(std::bind(&server::generate_thread, this, token, text)); + } + + void cmd_check(websocketpp::connection_hdl connection, std::string token) { + bool failed = false; + + try { + database_.subscribe(token, [this, connection](const std::string& msg) { + socket_.send(connection, msg, websocketpp::frame::opcode::value::TEXT); + }); + + if (!database_.is_done(token)) { + return; + } + + std::string result = database_.getResult(token); + nlohmann::json resultMsg; + if (result.empty()) { + resultMsg["type"] = "error"; + resultMsg["msg"] = "Unknown error occurred."; + } else { + resultMsg["type"] = "result"; + resultMsg["image"] = database_.getResult(token); + resultMsg["msg"] = "Success!"; + } + + socket_.send(connection, resultMsg.dump(), + websocketpp::frame::opcode::value::TEXT); + } catch (const std::exception& ex) { + failed = true; + } + + if (failed) { + try { + socket_.send(connection, R"( + { + "type": "error", + "msg": "Error retrieving request." + })", + websocketpp::frame::opcode::value::TEXT); + } catch (const std::exception& ex) { + // Well, okay + std::cout << ex.what() << std::endl; + } + } + } + + void generate_thread(std::string token, std::string text) { + std::unique_ptr generator; + + { + std::lock_guard rng_guard(rng_mutex_); + generator = std::make_unique(cards_, images_, text, rng_); + } + + try { + generator->set_status_callback([this, token](const std::string& status) { + nlohmann::json msg; + msg["type"] = "status"; + msg["msg"] = status; + + database_.post(token, msg.dump()); + }); + + Magick::Image resultImage = generator->run(); + Magick::Blob resultBlob; + resultImage.write(&resultBlob); + + std::string resultBytes((const char*)resultBlob.data(), + resultBlob.length()); + std::string resultEncoded = base64::to_base64(resultBytes); + + database_.setResult(token, resultEncoded); + + nlohmann::json resultMsg; + resultMsg["type"] = "result"; + resultMsg["image"] = resultEncoded; + resultMsg["msg"] = "Success!"; + + database_.post(token, resultMsg.dump()); + } catch (const std::exception& ex) { + nlohmann::json response; + response["type"] = "error"; + response["msg"] = + std::string("Error generating card (") + ex.what() + ")"; + + database_.post(token, response.dump()); + } + + database_.mark_done(token); + } + + void cleanup_thread() { + for (;;) { + // sleep + } + } + + const cardset& cards_; + const imagestore& images_; + + std::mutex rng_mutex_; + std::mt19937& rng_; + + socket_type socket_; + database database_; +}; + +} // namespace + +int main(int argc, char** argv) { + Magick::InitializeMagick(nullptr); + + std::random_device randomDevice; + std::mt19937 rng(5); // randomDevice()); + + if (argc != 2) { + std::cout << "usage: wizard_server [configfile]" << std::endl; + return -1; + } + + std::ifstream config_file(argv[1]); + nlohmann::json config_data = nlohmann::json::parse(config_file); + + cardset cards(config_data["cards_path"]); + imagestore images(config_data["cache_path"]); + + server app(cards, images, rng); + app.run(); +} -- cgit 1.4.1