/* Simple WebSocket relay facilitating client-to-client connections. * * Author: Steffen Vogel * SPDX-FileCopyrightText: 2014-2023 Institute for Automation of Complex Power Systems, RWTH Aachen University * SPDX-License-Identifier: Apache-2.0 */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "villas-relay.hpp" using namespace villas; using namespace villas::node; namespace villas { namespace node { namespace tools { RelaySession::RelaySession(Relay *r, Identifier sid) : identifier(sid), connects(0) { auto loggerName = fmt::format("relay:{}", sid); logger = villas::logging.get(loggerName); logger->info("Session created: {}", identifier); sessions[sid] = this; created = time(nullptr); uuid::generateFromString(uuid, identifier, r->uuid); } RelaySession::~RelaySession() { logger->info("Session destroyed: {}", identifier); sessions.erase(identifier); } RelaySession *RelaySession::get(Relay *r, lws *wsi) { char uri[64]; /* We use the URI to associate this connection to a session * Example: ws://example.com/node_1 * Will select the session with the name 'node_1' */ // Get path of incoming request lws_hdr_copy(wsi, uri, sizeof(uri), WSI_TOKEN_GET_URI); if (strlen(uri) <= 1) throw InvalidUrlException(); Identifier sid = uri + 1; auto it = sessions.find(sid); if (it == sessions.end()) { auto *rs = new RelaySession(r, sid); if (!rs) throw MemoryAllocationError(); return rs; } else { auto logger = logging.get("villas-relay"); logger->info("Found existing session: {}", sid); return it->second; } } json_t *RelaySession::toJson() const { json_t *json_connections = json_array(); for (auto it : connections) { auto conn = it.second; json_array_append(json_connections, conn->toJson()); } return json_pack("{ s: s, s: s, s: o, s: I, s: i }", "identifier", identifier.c_str(), "uuid", uuid::toString(uuid).c_str(), "connections", json_connections, "created", created, "connects", connects); } std::map RelaySession::sessions; RelayConnection::RelayConnection(Relay *r, lws *w, bool lo) : wsi(w), currentFrame(std::make_shared()), outgoingFrames(), bytes_recv(0), bytes_sent(0), frames_recv(0), frames_sent(0), loopback(lo) { session = RelaySession::get(r, wsi); session->connections[wsi] = this; session->connects++; lws_get_peer_addresses(wsi, lws_get_socket_fd(wsi), name, sizeof(name), ip, sizeof(ip)); created = time(nullptr); session->logger->info("New connection established: {} ({})", name, ip); } RelayConnection::~RelayConnection() { session->logger->info("Connection closed: {} ({})", name, ip); session->connections.erase(wsi); if (session->connections.empty()) delete session; } json_t *RelayConnection::toJson() const { return json_pack("{ s: s, s: s, s: I, s: I, s: I, s: I, s: I }", "name", name, "ip", ip, "created", created, "bytes_recv", bytes_recv, "bytes_sent", bytes_sent, "frames_recv", frames_recv, "frames_sent", frames_sent); } void RelayConnection::write() { if (outgoingFrames.empty()) return; auto fr = outgoingFrames.front(); int ret = lws_write(wsi, fr->data(), fr->size(), LWS_WRITE_BINARY); if (ret < 0) return; bytes_sent += fr->size(); frames_sent++; outgoingFrames.pop(); if (outgoingFrames.size() > 0) lws_callback_on_writable(wsi); } void RelayConnection::read(void *in, size_t len) { currentFrame->insert(currentFrame->end(), (uint8_t *)in, (uint8_t *)in + len); bytes_recv += len; if (lws_is_final_fragment(wsi)) { frames_recv++; session->logger->debug("Received frame, relaying to {} connections", session->connections.size() - (loopback ? 0 : 1)); for (auto p : session->connections) { auto c = p.second; /* We skip the current connection in order * to avoid receiving our own data */ if (loopback == false && c == this) continue; c->outgoingFrames.push(currentFrame); lws_callback_on_writable(c->wsi); } currentFrame = std::make_shared(); } } Relay::Relay(int argc, char *argv[]) : Tool(argc, argv, "relay"), stop(false), context(nullptr), vhost(nullptr), loopback(false), port(8088), protocol("live") { int ret; char hname[128]; ret = gethostname(hname, sizeof(hname)); if (ret) throw SystemError("Failed to get hostname"); hostname = hname; // Default UUID is derived from hostname uuid::generateFromString(uuid, hname); ret = memory::init(0); if (ret) throw RuntimeError("Failed to initialize memory"); // Initialize logging lws_set_log_level(Web::lwsLogLevel(logging.getLevel()), Web::lwsLogger); protocols = {{.name = "http", .callback = lws_callback_http_dummy, .per_session_data_size = 0, .rx_buffer_size = 1024}, {.name = "http-api", .callback = httpProtocolCallback, .per_session_data_size = 0, .rx_buffer_size = 1024}, {.name = "live", .callback = protocolCallback, .per_session_data_size = sizeof(RelayConnection), .rx_buffer_size = 0}, {nullptr /* terminator */}}; } int Relay::httpProtocolCallback(lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len) { int ret; lws_context *ctx = lws_get_context(wsi); void *user_ctx = lws_context_user(ctx); Relay *r = reinterpret_cast(user_ctx); switch (reason) { case LWS_CALLBACK_HTTP: { unsigned char buf[LWS_PRE + 2048], *start = &buf[LWS_PRE], *end = &buf[sizeof(buf) - LWS_PRE - 1], *p = start; if (lws_add_http_common_headers( wsi, HTTP_STATUS_OK, "application/json", LWS_ILLEGAL_HTTP_CONTENT_LEN, // no content len &p, end)) return 1; if (lws_finalize_write_http_header(wsi, start, &p, end)) return 1; // Write the body separately lws_callback_on_writable(wsi); return 0; } case LWS_CALLBACK_HTTP_WRITEABLE: { size_t len; std::vector buf; json_t *json_sessions, *json_body; json_sessions = json_array(); for (auto it : RelaySession::sessions) { auto &session = it.second; json_array_append(json_sessions, session->toJson()); } json_body = json_pack("{ s: o, s: s, s: s, s: s, s: { s: b, s: i, s: s } }", "sessions", json_sessions, "version", PROJECT_VERSION_STR, "hostname", r->hostname.c_str(), "uuid", uuid::toString(r->uuid).c_str(), "options", "loopback", r->loopback, "port", r->port, "protocol", r->protocol.c_str()); if (!json_body) return -1; len = 1024; do { buf.resize(LWS_PRE + len); len = json_dumpb(json_body, buf.data() + LWS_PRE, buf.size() - LWS_PRE, JSON_INDENT(4)); if (len == 0) return -1; } while (len > buf.size() - LWS_PRE); ret = lws_write(wsi, (unsigned char *)(buf.data() + LWS_PRE), len, LWS_WRITE_HTTP_FINAL); if (ret < 0) return ret; r->logger->info("Handled API request"); return -1; } default: break; } return lws_callback_http_dummy(wsi, reason, user, in, len); } int Relay::protocolCallback(lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len) { lws_context *ctx = lws_get_context(wsi); void *user_ctx = lws_context_user(ctx); Relay *r = reinterpret_cast(user_ctx); RelayConnection *c = reinterpret_cast(user); switch (reason) { case LWS_CALLBACK_ESTABLISHED: try { new (c) RelayConnection(r, wsi, r->loopback); } catch (const InvalidUrlException &e) { lws_close_reason(wsi, LWS_CLOSE_STATUS_PROTOCOL_ERR, (unsigned char *)"Invalid URL", strlen("Invalid URL")); return -1; } break; case LWS_CALLBACK_CLOSED: c->~RelayConnection(); break; case LWS_CALLBACK_SERVER_WRITEABLE: c->write(); break; case LWS_CALLBACK_RECEIVE: c->read(in, len); break; default: break; } return 0; } void Relay::usage() { std::cout << "Usage: villas-relay [OPTIONS]" << std::endl << " OPTIONS is one or more of the following options:" << std::endl << " -d LVL set debug level" << std::endl << " -p PORT the port number to listen on" << std::endl << " -P PROT the websocket protocol" << std::endl << " -l enable loopback of own data" << std::endl << " -u UUID unique instance id" << std::endl << " -V show version and exit" << std::endl << " -h show usage and exit" << std::endl << std::endl; printCopyright(); } void Relay::parse() { int ret; char c, *endptr; while ((c = getopt(argc, argv, "hVp:P:ld:u:")) != -1) { switch (c) { case 'd': logging.setLevel(optarg); lws_set_log_level(Web::lwsLogLevel(logging.getLevel()), Web::lwsLogger); break; case 'p': port = strtoul(optarg, &endptr, 10); goto check; case 'P': protocol = optarg; break; case 'l': loopback = true; break; case 'u': ret = uuid_parse(optarg, uuid); if (ret) { logger->error("Failed to parse UUID: {}", optarg); exit(EXIT_FAILURE); } break; case 'V': printVersion(); exit(EXIT_SUCCESS); case 'h': case '?': usage(); exit(c == '?' ? EXIT_FAILURE : EXIT_SUCCESS); } continue; check: if (optarg == endptr) { logger->error("Failed to parse parse option argument '-{} {}'", c, optarg); exit(EXIT_FAILURE); } } if (argc - optind < 0) { usage(); exit(EXIT_FAILURE); } } int Relay::main() { // Start server lws_context_creation_info ctx_info = {0}; protocols[2].name = protocol.c_str(); ctx_info.options = LWS_SERVER_OPTION_EXPLICIT_VHOSTS | LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT; ctx_info.gid = -1; ctx_info.uid = -1; ctx_info.protocols = protocols.data(); #ifndef LWS_WITHOUT_EXTENSIONS ctx_info.extensions = extensions.data(); #endif ctx_info.port = port; ctx_info.mounts = &mount; ctx_info.user = (void *)this; auto lwsLogger = logging.get("lws"); context = lws_create_context(&ctx_info); if (context == nullptr) { lwsLogger->error("Failed to initialize server context"); exit(EXIT_FAILURE); } vhost = lws_create_vhost(context, &ctx_info); if (vhost == nullptr) { lwsLogger->error("Failed to initialize virtual host"); exit(EXIT_FAILURE); } while (!stop) lws_service(context, 100); return 0; } const std::vector Relay::extensions = { #ifdef LWS_DEFLATE_FOUND {"permessage-deflate", lws_extension_callback_pm_deflate, "permessage-deflate"}, {"deflate-frame", lws_extension_callback_pm_deflate, "deflate_frame"}, #endif // LWS_DEFLATE_FOUND {nullptr /* terminator */}}; const lws_http_mount Relay::mount = { .mount_next = nullptr, // linked-list "next" .mountpoint = "/api/v1", // mountpoint URL .origin = nullptr, // protocol .def = nullptr, .protocol = "http-api", .cgienv = nullptr, .extra_mimetypes = nullptr, .interpret = nullptr, .cgi_timeout = 0, .cache_max_age = 0, .auth_mask = 0, .cache_reusable = 0, .cache_revalidate = 0, .cache_intermediaries = 0, .origin_protocol = LWSMPRO_CALLBACK, // dynamic .mountpoint_len = 7, // char count .basic_auth_login_file = nullptr, }; } // namespace tools } // namespace node } // namespace villas int main(int argc, char *argv[]) { villas::node::tools::Relay t(argc, argv); return t.run(); }