From 3fb55b243e0e782545d26faced2ae5660e02c5a6 Mon Sep 17 00:00:00 2001 From: Jan Kaluza Date: Tue, 24 Nov 2015 21:00:04 +0100 Subject: [PATCH] Working Slack RTM and message receiving. --- include/transport/WebSocketClient.h | 79 +++++++ spectrum/src/frontends/slack/SlackAPI.cpp | 128 +++++++++++- spectrum/src/frontends/slack/SlackAPI.h | 40 +++- .../src/frontends/slack/SlackInstallation.cpp | 49 +++-- .../src/frontends/slack/SlackInstallation.h | 6 +- spectrum/src/frontends/slack/SlackRTM.cpp | 116 ++++++----- spectrum/src/frontends/slack/SlackRTM.h | 40 +++- src/WebSocketClient.cpp | 193 ++++++++++++++++++ 8 files changed, 556 insertions(+), 95 deletions(-) create mode 100644 include/transport/WebSocketClient.h create mode 100644 src/WebSocketClient.cpp diff --git a/include/transport/WebSocketClient.h b/include/transport/WebSocketClient.h new file mode 100644 index 00000000..878025f6 --- /dev/null +++ b/include/transport/WebSocketClient.h @@ -0,0 +1,79 @@ +/** + * Spectrum 2 Slack Frontend + * + * Copyright (C) 2015, Jan Kaluza + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "Swiften/Version.h" + +#define HAVE_SWIFTEN_3 (SWIFTEN_VERSION >= 0x030000) + +#if HAVE_SWIFTEN_3 +#include +#endif + +#include +#include +#include + +#include + +namespace Transport { + +class Component; + +class WebSocketClient { + public: + WebSocketClient(Component *component); + + virtual ~WebSocketClient(); + + void connectServer(const std::string &u); + + void write(const std::string &data); + + boost::signal onPayloadReceived; + + private: + void handleDNSResult(const std::vector&, boost::optional); + void handleDataRead(boost::shared_ptr data); + void handleConnected(bool error); + + private: + Component *m_component; + boost::shared_ptr m_dnsQuery; + boost::shared_ptr m_conn; + Swift::TLSConnectionFactory *m_tlsConnectionFactory; + Swift::PlatformTLSFactories *m_tlsFactory; + std::string m_host; + std::string m_path; + std::string m_buffer; + bool m_upgraded; +}; + +} diff --git a/spectrum/src/frontends/slack/SlackAPI.cpp b/spectrum/src/frontends/slack/SlackAPI.cpp index 5f560f0e..3b698225 100644 --- a/spectrum/src/frontends/slack/SlackAPI.cpp +++ b/spectrum/src/frontends/slack/SlackAPI.cpp @@ -36,8 +36,9 @@ namespace Transport { DEFINE_LOGGER(logger, "SlackAPI"); -SlackAPI::SlackAPI(Component *component, UserInfo uinfo) : m_uinfo(uinfo) { +SlackAPI::SlackAPI(Component *component, const std::string &token) { m_component = component; + m_token = token; } SlackAPI::~SlackAPI() { @@ -52,7 +53,7 @@ void SlackAPI::sendMessage(const std::string &from, const std::string &to, const url += "&username=" + Util::urlencode(from); url += "&channel=" + Util::urlencode(to); url += "&text=" + Util::urlencode(text); - url += "&token=" + Util::urlencode(m_uinfo.encoding); + url += "&token=" + Util::urlencode(m_token); HTTPRequest *req = new HTTPRequest(THREAD_POOL(m_component), HTTPRequest::Get, url, boost::bind(&SlackAPI::handleSendMessage, this, _1, _2, _3, _4)); @@ -84,7 +85,7 @@ std::string SlackAPI::getChannelId(HTTPRequest *req, bool ok, rapidjson::Documen } void SlackAPI::imOpen(const std::string &uid, HTTPRequest::Callback callback) { - std::string url = "https://slack.com/api/im.open?user=" + Util::urlencode(uid) + "&token=" + Util::urlencode(m_uinfo.encoding); + std::string url = "https://slack.com/api/im.open?user=" + Util::urlencode(uid) + "&token=" + Util::urlencode(m_token); HTTPRequest *req = new HTTPRequest(THREAD_POOL(m_component), HTTPRequest::Get, url, callback); queueRequest(req); } @@ -125,9 +126,128 @@ std::string SlackAPI::getOwnerId(HTTPRequest *req, bool ok, rapidjson::Document } void SlackAPI::usersList(HTTPRequest::Callback callback) { - std::string url = "https://slack.com/api/users.list?presence=0&token=" + Util::urlencode(m_uinfo.encoding); + std::string url = "https://slack.com/api/users.list?presence=0&token=" + Util::urlencode(m_token); HTTPRequest *req = new HTTPRequest(THREAD_POOL(m_component), HTTPRequest::Get, url, callback); queueRequest(req); } +#define GET_ARRAY(FROM, NAME) rapidjson::Value &NAME = FROM[#NAME]; \ + if (!NAME.IsArray()) { \ + LOG4CXX_ERROR(logger, "No '" << #NAME << "' object in the reply."); \ + return; \ + } + +#define STORE_STRING(FROM, NAME) rapidjson::Value &NAME##_tmp = FROM[#NAME]; \ + if (!NAME##_tmp.IsString()) { \ + LOG4CXX_ERROR(logger, "No '" << #NAME << "' string in the reply."); \ + LOG4CXX_ERROR(logger, data); \ + return; \ + } \ + std::string NAME = NAME##_tmp.GetString(); + +#define STORE_BOOL(FROM, NAME) rapidjson::Value &NAME##_tmp = FROM[#NAME]; \ + if (!NAME##_tmp.IsBool()) { \ + LOG4CXX_ERROR(logger, "No '" << #NAME << "' string in the reply."); \ + LOG4CXX_ERROR(logger, data); \ + return; \ + } \ + bool NAME = NAME##_tmp.GetBool(); + +void SlackAPI::getSlackChannelInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &ret) { + if (!ok) { + LOG4CXX_ERROR(logger, req->getError()); + return; + } + + GET_ARRAY(resp, channels); + + for (int i = 0; i < channels.Size(); i++) { + if (!channels[i].IsObject()) { + continue; + } + + SlackChannelInfo info; + + STORE_STRING(channels[i], id); + info.id = id; + + STORE_STRING(channels[i], name); + info.name = name; + + rapidjson::Value &members = channels[i]["members"]; + for (int y = 0; members.IsArray() && y < members.Size(); y++) { + if (!members[i].IsString()) { + continue; + } + + info.members.push_back(members[i].GetString()); + } + + ret[info.id] = info; + } + + return; +} + +void SlackAPI::getSlackImInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &ret) { + if (!ok) { + LOG4CXX_ERROR(logger, req->getError()); + return; + } + + GET_ARRAY(resp, ims); + + for (int i = 0; i < ims.Size(); i++) { + if (!ims[i].IsObject()) { + continue; + } + + SlackImInfo info; + + STORE_STRING(ims[i], id); + info.id = id; + + STORE_STRING(ims[i], user); + info.user = user; + + ret[info.id] = info; + LOG4CXX_INFO(logger, info.id << " " << info.user); + } + + return; +} + +void SlackAPI::getSlackUserInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &ret) { + if (!ok) { + LOG4CXX_ERROR(logger, req->getError()); + return; + } + + GET_ARRAY(resp, users); + + for (int i = 0; i < users.Size(); i++) { + if (!users[i].IsObject()) { + continue; + } + + SlackUserInfo info; + + STORE_STRING(users[i], id); + info.id = id; + + STORE_STRING(users[i], name); + info.name = name; + + STORE_BOOL(users[i], is_primary_owner); + info.isPrimaryOwner = is_primary_owner; + + ret[info.id] = info; + LOG4CXX_INFO(logger, info.id << " " << info.name); + } + + return; +} + + + } diff --git a/spectrum/src/frontends/slack/SlackAPI.h b/spectrum/src/frontends/slack/SlackAPI.h index 7037938b..039247d8 100644 --- a/spectrum/src/frontends/slack/SlackAPI.h +++ b/spectrum/src/frontends/slack/SlackAPI.h @@ -22,7 +22,6 @@ #include "transport/HTTPRequestQueue.h" #include "transport/HTTPRequest.h" -#include "transport/StorageBackend.h" #include "rapidjson/document.h" #include @@ -36,11 +35,40 @@ namespace Transport { class Component; class StorageBackend; class HTTPRequest; -class SlackRTM; + +class SlackChannelInfo { + public: + SlackChannelInfo() {} + virtual ~SlackChannelInfo() {} + + std::string id; + std::string name; + std::vector members; +}; + +class SlackImInfo { + public: + SlackImInfo() {} + virtual ~SlackImInfo() {} + + std::string id; + std::string user; +}; + +class SlackUserInfo { + public: + SlackUserInfo() : isPrimaryOwner(false) {} + virtual ~SlackUserInfo() {} + + std::string id; + std::string name; + bool isPrimaryOwner; +}; + class SlackAPI : public HTTPRequestQueue { public: - SlackAPI(Component *component, UserInfo uinfo); + SlackAPI(Component *component, const std::string &token); virtual ~SlackAPI(); @@ -52,12 +80,16 @@ class SlackAPI : public HTTPRequestQueue { void sendMessage(const std::string &from, const std::string &to, const std::string &text); + static void getSlackChannelInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &channels); + static void getSlackImInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &ims); + static void getSlackUserInfo(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data, std::map &users); + private: void handleSendMessage(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data); private: Component *m_component; - UserInfo m_uinfo; + std::string m_token; }; } diff --git a/spectrum/src/frontends/slack/SlackInstallation.cpp b/spectrum/src/frontends/slack/SlackInstallation.cpp index 6f3eae73..2d67ca30 100644 --- a/spectrum/src/frontends/slack/SlackInstallation.cpp +++ b/spectrum/src/frontends/slack/SlackInstallation.cpp @@ -22,7 +22,6 @@ #include "SlackFrontend.h" #include "SlackUser.h" #include "SlackRTM.h" -#include "SlackAPI.h" #include "transport/Transport.h" #include "transport/HTTPRequest.h" @@ -40,40 +39,46 @@ DEFINE_LOGGER(logger, "SlackInstallation"); SlackInstallation::SlackInstallation(Component *component, StorageBackend *storageBackend, UserInfo uinfo) : m_uinfo(uinfo) { m_component = component; m_storageBackend = storageBackend; - m_api = new SlackAPI(component, uinfo); - - m_api->usersList(boost::bind(&SlackInstallation::handleUsersList, this, _1, _2, _3, _4)); -// m_rtm = new SlackRTM(component, storageBackend, uinfo); + m_rtm = new SlackRTM(component, storageBackend, uinfo); + m_rtm->onRTMStarted.connect(boost::bind(&SlackInstallation::handleRTMStarted, this)); + m_rtm->onMessageReceived.connect(boost::bind(&SlackInstallation::handleMessageReceived, this, _1, _2, _3)); } SlackInstallation::~SlackInstallation() { -// delete m_rtm; - delete m_api; + delete m_rtm; +} + +void SlackInstallation::handleMessageReceived(const std::string &channel, const std::string &user, const std::string &message) { + if (m_ownerChannel == channel) { + LOG4CXX_INFO(logger, "Owner message received " << channel << " " << user << " " << message); + } } void SlackInstallation::handleImOpen(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data) { - std::string channel = m_api->getChannelId(req, ok, resp, data); - LOG4CXX_INFO(logger, "Opened channel with team owner: " << channel); + m_ownerChannel = m_rtm->getAPI()->getChannelId(req, ok, resp, data); + LOG4CXX_INFO(logger, "Opened channel with team owner: " << m_ownerChannel); std::string msg; - msg = "Hi, It seems you have authorized Spectrum 2 transport for your team. " - "As a team owner, you should now configure it. You should provide username and " - "password you want to use to connect your team to legacy network of your choice."; - m_api->sendMessage("Spectrum 2", channel, msg); + msg = "Hi, it seems you have enabled Spectrum 2 transport for your Team. As a Team owner, you should now configure it."; + m_rtm->sendMessage(m_ownerChannel, msg); - msg = "You can do it by typing \".spectrum2 register \". Password may be optional."; - m_api->sendMessage("Spectrum 2", channel, msg); - - msg = "For example to connect the Freenode IRC network, just type \".spectrum2 register irc.freenode.net\"."; - m_api->sendMessage("Spectrum 2", channel, msg); + msg = "To configure IRC network you want to connect to, type: \".spectrum2 register @\". For example for Freenode, the command looks like \".spectrum2 register MySlackBot@irc.freenode.net\"."; + m_rtm->sendMessage(m_ownerChannel, msg); } -void SlackInstallation::handleUsersList(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data) { - std::string ownerId = m_api->getOwnerId(req, ok, resp, data); - LOG4CXX_INFO(logger, "Team owner ID is " << ownerId); +void SlackInstallation::handleRTMStarted() { + std::string ownerId; + std::map &users = m_rtm->getUsers(); + for (std::map::iterator it = users.begin(); it != users.end(); it++) { + SlackUserInfo &info = it->second; + if (info.isPrimaryOwner) { + ownerId = it->first; + break; + } + } - m_api->imOpen(ownerId, boost::bind(&SlackInstallation::handleImOpen, this, _1, _2, _3, _4)); + m_rtm->getAPI()->imOpen(ownerId, boost::bind(&SlackInstallation::handleImOpen, this, _1, _2, _3, _4)); } diff --git a/spectrum/src/frontends/slack/SlackInstallation.h b/spectrum/src/frontends/slack/SlackInstallation.h index 73756d92..a126547a 100644 --- a/spectrum/src/frontends/slack/SlackInstallation.h +++ b/spectrum/src/frontends/slack/SlackInstallation.h @@ -35,7 +35,6 @@ class Component; class StorageBackend; class HTTPRequest; class SlackRTM; -class SlackAPI; class SlackInstallation { public: @@ -46,7 +45,8 @@ class SlackInstallation { boost::signal onInstallationDone; private: - void handleUsersList(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data); + void handleRTMStarted(); + void handleMessageReceived(const std::string &channel, const std::string &user, const std::string &message); void handleImOpen(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data); private: @@ -55,7 +55,7 @@ class SlackInstallation { UserInfo m_uinfo; std::string m_ownerName; SlackRTM *m_rtm; - SlackAPI *m_api; + std::string m_ownerChannel; }; } diff --git a/spectrum/src/frontends/slack/SlackRTM.cpp b/spectrum/src/frontends/slack/SlackRTM.cpp index 1d12860e..1a5e72f0 100644 --- a/spectrum/src/frontends/slack/SlackRTM.cpp +++ b/spectrum/src/frontends/slack/SlackRTM.cpp @@ -25,9 +25,11 @@ #include "transport/Transport.h" #include "transport/HTTPRequest.h" #include "transport/Util.h" +#include "transport/WebSocketClient.h" #include #include +#include #include #include @@ -38,29 +40,67 @@ DEFINE_LOGGER(logger, "SlackRTM"); SlackRTM::SlackRTM(Component *component, StorageBackend *storageBackend, UserInfo uinfo) : m_uinfo(uinfo) { m_component = component; m_storageBackend = storageBackend; + m_counter = 0; + m_client = new WebSocketClient(component); + m_client->onPayloadReceived.connect(boost::bind(&SlackRTM::handlePayloadReceived, this, _1)); + m_pingTimer = m_component->getNetworkFactories()->getTimerFactory()->createTimer(20000); + m_pingTimer->onTick.connect(boost::bind(&SlackRTM::sendPing, this)); + int type = (int) TYPE_STRING; + m_storageBackend->getUserSetting(m_uinfo.id, "bot_token", type, m_token); -#if HAVE_SWIFTEN_3 - Swift::TLSOptions o; -#endif - Swift::PlatformTLSFactories *m_tlsFactory = new Swift::PlatformTLSFactories(); -#if HAVE_SWIFTEN_3 - m_tlsConnectionFactory = new Swift::TLSConnectionFactory(m_tlsFactory->getTLSContextFactory(), component->getNetworkFactories()->getConnectionFactory(), o); -#else - m_tlsConnectionFactory = new Swift::TLSConnectionFactory(m_tlsFactory->getTLSContextFactory(), component->getNetworkFactories()->getConnectionFactory()); -#endif - + m_api = new SlackAPI(component, m_token); std::string url = "https://slack.com/api/rtm.start?"; - url += "token=" + Util::urlencode(m_uinfo.encoding); + url += "token=" + Util::urlencode(m_token); -// HTTPRequest *req = new HTTPRequest(); -// req->GET(THREAD_POOL(m_component), url, -// boost::bind(&SlackRTM::handleRTMStart, this, _1, _2, _3, _4)); + HTTPRequest *req = new HTTPRequest(THREAD_POOL(m_component), HTTPRequest::Get, url, boost::bind(&SlackRTM::handleRTMStart, this, _1, _2, _3, _4)); + req->execute(); } SlackRTM::~SlackRTM() { + delete m_client; + delete m_api; + m_pingTimer->stop(); +} +#define STORE_STRING(FROM, NAME) rapidjson::Value &NAME##_tmp = FROM[#NAME]; \ + if (!NAME##_tmp.IsString()) { \ + LOG4CXX_ERROR(logger, "No '" << #NAME << "' string in the reply."); \ + LOG4CXX_ERROR(logger, payload); \ + return; \ + } \ + std::string NAME = NAME##_tmp.GetString(); + +void SlackRTM::handlePayloadReceived(const std::string &payload) { + rapidjson::Document d; + if (d.Parse<0>(payload.c_str()).HasParseError()) { + LOG4CXX_ERROR(logger, "Error while parsing JSON"); + LOG4CXX_ERROR(logger, payload); + return; + } + + STORE_STRING(d, type); + + if (type == "message") { + STORE_STRING(d, channel); + STORE_STRING(d, user); + STORE_STRING(d, text); + onMessageReceived(channel, user, text); + } +} + +void SlackRTM::sendMessage(const std::string &channel, const std::string &message) { + m_counter++; + std::string msg = "{\"id\": " + boost::lexical_cast(m_counter) + ", \"type\": \"message\", \"channel\":\"" + channel + "\", \"text\":\"" + message + "\"}"; + m_client->write(msg); +} + +void SlackRTM::sendPing() { + m_counter++; + std::string msg = "{\"id\": " + boost::lexical_cast(m_counter) + ", \"type\": \"ping\"}"; + m_client->write(msg); + m_pingTimer->start(); } void SlackRTM::handleRTMStart(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data) { @@ -77,51 +117,19 @@ void SlackRTM::handleRTMStart(HTTPRequest *req, bool ok, rapidjson::Document &re return; } + SlackAPI::getSlackChannelInfo(req, ok, resp, data, m_channels); + SlackAPI::getSlackImInfo(req, ok, resp, data, m_ims); + SlackAPI::getSlackUserInfo(req, ok, resp, data, m_users); + std::string u = url.GetString(); LOG4CXX_INFO(logger, "Started RTM, WebSocket URL is " << u); + LOG4CXX_INFO(logger, data); - u = u.substr(6); - m_host = u.substr(0, u.find("/")); - m_path = u.substr(u.find("/")); + m_client->connectServer(u); + m_pingTimer->start(); - LOG4CXX_INFO(logger, "Starting DNS query for " << m_host << " " << m_path); - m_dnsQuery = m_component->getNetworkFactories()->getDomainNameResolver()->createAddressQuery(m_host); - m_dnsQuery->onResult.connect(boost::bind(&SlackRTM::handleDNSResult, this, _1, _2)); - m_dnsQuery->run(); + onRTMStarted(); } -void SlackRTM::handleDataRead(boost::shared_ptr data) { - LOG4CXX_INFO(logger, "data read"); - std::string d = Swift::safeByteArrayToString(*data); - LOG4CXX_INFO(logger, d); -} - -void SlackRTM::handleConnected(bool error) { - if (error) { - LOG4CXX_ERROR(logger, "Connection to " << m_host << " failed"); - return; - } - - LOG4CXX_INFO(logger, "Connected to " << m_host); - - std::string req = ""; - req += "GET " + m_path + " HTTP/1.1\r\n"; - req += "Host: " + m_host + ":443\r\n"; - req += "Upgrade: websocket\r\n"; - req += "Connection: Upgrade\r\n"; - req += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"; - req += "Sec-WebSocket-Version: 13\r\n"; - req += "\r\n"; - - m_conn->write(Swift::createSafeByteArray(req)); - -} - -void SlackRTM::handleDNSResult(const std::vector &addrs, boost::optional) { - m_conn = m_tlsConnectionFactory->createConnection(); - m_conn->onDataRead.connect(boost::bind(&SlackRTM::handleDataRead, this, _1)); - m_conn->onConnectFinished.connect(boost::bind(&SlackRTM::handleConnected, this, _1)); - m_conn->connect(Swift::HostAddressPort(addrs[0], 443)); -} } diff --git a/spectrum/src/frontends/slack/SlackRTM.h b/spectrum/src/frontends/slack/SlackRTM.h index 26b6c4bf..6e1b6d34 100644 --- a/spectrum/src/frontends/slack/SlackRTM.h +++ b/spectrum/src/frontends/slack/SlackRTM.h @@ -20,6 +20,8 @@ #pragma once +#include "SlackAPI.h" + #include "transport/StorageBackend.h" #include "rapidjson/document.h" @@ -31,6 +33,7 @@ #include #include #include +#include "Swiften/Network/Timer.h" #include "Swiften/Version.h" #define HAVE_SWIFTEN_3 (SWIFTEN_VERSION >= 0x030000) @@ -50,6 +53,8 @@ namespace Transport { class Component; class StorageBackend; class HTTPRequest; +class WebSocketClient; +class SlackAPI; class SlackRTM { public: @@ -57,21 +62,40 @@ class SlackRTM { virtual ~SlackRTM(); + void sendPing(); + + void sendMessage(const std::string &channel, const std::string &message); + + boost::signal onRTMStarted; + + std::map &getUsers() { + return m_users; + } + + SlackAPI *getAPI() { + return m_api; + } + + boost::signal onMessageReceived; + private: - void handleDNSResult(const std::vector&, boost::optional); - void handleDataRead(boost::shared_ptr data); - void handleConnected(bool error); + void handlePayloadReceived(const std::string &payload); void handleRTMStart(HTTPRequest *req, bool ok, rapidjson::Document &resp, const std::string &data); + private: + std::map m_channels; + std::map m_ims; + std::map m_users; + private: Component *m_component; StorageBackend *m_storageBackend; UserInfo m_uinfo; - boost::shared_ptr m_dnsQuery; - boost::shared_ptr m_conn; - Swift::TLSConnectionFactory *m_tlsConnectionFactory; - std::string m_host; - std::string m_path; + WebSocketClient *m_client; + std::string m_token; + unsigned long m_counter; + Swift::Timer::ref m_pingTimer; + SlackAPI *m_api; }; } diff --git a/src/WebSocketClient.cpp b/src/WebSocketClient.cpp new file mode 100644 index 00000000..b18b5da8 --- /dev/null +++ b/src/WebSocketClient.cpp @@ -0,0 +1,193 @@ +/** + * XMPP - libpurple transport + * + * Copyright (C) 2009, Jan Kaluza + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA + */ + +#include "transport/WebSocketClient.h" +#include "transport/Transport.h" +#include "transport/Util.h" +#include "transport/Logging.h" + +#include +#include +#include +#include + +namespace Transport { + +DEFINE_LOGGER(logger, "WebSocketClient"); + +WebSocketClient::WebSocketClient(Component *component) { + m_component = component; + m_upgraded = false; + +#if HAVE_SWIFTEN_3 + Swift::TLSOptions o; +#endif + m_tlsFactory = new Swift::PlatformTLSFactories(); +#if HAVE_SWIFTEN_3 + m_tlsConnectionFactory = new Swift::TLSConnectionFactory(m_tlsFactory->getTLSContextFactory(), component->getNetworkFactories()->getConnectionFactory(), o); +#else + m_tlsConnectionFactory = new Swift::TLSConnectionFactory(m_tlsFactory->getTLSContextFactory(), component->getNetworkFactories()->getConnectionFactory()); +#endif +} + +WebSocketClient::~WebSocketClient() { + if (m_conn) { + m_conn->onDataRead.disconnect(boost::bind(&WebSocketClient::handleDataRead, this, _1)); + m_conn->disconnect(); + } + + delete m_tlsFactory; + delete m_tlsConnectionFactory; +} + +void WebSocketClient::connectServer(const std::string &url) { + std::string u = url.substr(6); + m_host = u.substr(0, u.find("/")); + m_path = u.substr(u.find("/")); + + LOG4CXX_INFO(logger, "Starting DNS query for " << m_host << " " << m_path); + m_dnsQuery = m_component->getNetworkFactories()->getDomainNameResolver()->createAddressQuery(m_host); + m_dnsQuery->onResult.connect(boost::bind(&WebSocketClient::handleDNSResult, this, _1, _2)); + m_dnsQuery->run(); +} + +void WebSocketClient::write(const std::string &data) { + if (!m_conn) { + return; + } + + uint8_t opcode = 129; // UTF8 + if (data.empty()) { + LOG4CXX_INFO(logger, "pong"); + opcode = 138; // PONG + } + + // Mask the payload + char mask_bits[4] = {0x11, 0x22, 0x33, 0x44}; + std::string payload = data; + for (size_t i = 0; i < data.size(); i++ ) { + payload[i] = payload[i] ^ mask_bits[i&3]; + } + + if (data.size() <= 125) { + uint8_t size7 = data.size() + 128; // Mask bit + m_conn->write(Swift::createSafeByteArray(std::string((char *) &opcode, 1) + + std::string((char *) &size7, 1) + + std::string((char *) &mask_bits[0], 4) + + payload)); + } + else { + uint8_t size7 = 126 + 128; // Mask bit + uint16_t size16 = data.size(); + size16 = htons(size16); + m_conn->write(Swift::createSafeByteArray(std::string((char *) &opcode, 1) + + std::string((char *) &size7, 1) + + std::string((char *) &size16, 2) + + std::string((char *) &mask_bits[0], 4) + + payload)); + } + + LOG4CXX_INFO(logger, "> " << data); +} + +void WebSocketClient::handleDataRead(boost::shared_ptr data) { + std::string d = Swift::safeByteArrayToString(*data); + m_buffer += d; + + if (!m_upgraded) { + if (m_buffer.find("\r\n\r\n") != std::string::npos) { + m_buffer.erase(0, m_buffer.find("\r\n\r\n") + 4); + m_upgraded = true; + } + else { + return; + } + } + + while (m_buffer.size() > 0) { + if (m_buffer.size() >= 2) { + uint8_t opcode = *((uint8_t *) &m_buffer[0]) & 0xf; + uint8_t size7 = *((uint8_t *) &m_buffer[1]); + uint16_t size16 = 0; + int header_size = 2; + if (size7 == 126) { + if (m_buffer.size() >= 4) { + size16 = *((uint16_t *) &m_buffer[2]); + size16 = ntohs(size16); + header_size += 2; + } + else { + return; + } + } + +// if (opcode == 9) { +// write(""); +// } + + unsigned int size = (size16 == 0 ? size7 : size16); + if (m_buffer.size() >= size + header_size) { + std::string payload = m_buffer.substr(header_size, size); + LOG4CXX_INFO(logger, "< " << payload); + onPayloadReceived(payload); + m_buffer.erase(0, size + header_size); + + } + else if (size == 0) { + m_buffer.erase(0, header_size); + } + else { + return; + } + } + else { + return; + } + } +} + +void WebSocketClient::handleConnected(bool error) { + if (error) { + LOG4CXX_ERROR(logger, "Connection to " << m_host << " failed"); + return; + } + + LOG4CXX_INFO(logger, "Connected to " << m_host); + + std::string req = ""; + req += "GET " + m_path + " HTTP/1.1\r\n"; + req += "Host: " + m_host + ":443\r\n"; + req += "Upgrade: websocket\r\n"; + req += "Connection: Upgrade\r\n"; + req += "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"; + req += "Sec-WebSocket-Version: 13\r\n"; + req += "\r\n"; + + m_conn->write(Swift::createSafeByteArray(req)); +} + +void WebSocketClient::handleDNSResult(const std::vector &addrs, boost::optional) { + m_conn = m_tlsConnectionFactory->createConnection(); + m_conn->onDataRead.connect(boost::bind(&WebSocketClient::handleDataRead, this, _1)); + m_conn->onConnectFinished.connect(boost::bind(&WebSocketClient::handleConnected, this, _1)); + m_conn->connect(Swift::HostAddressPort(addrs[0], 443)); +} + +}