/* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "Swiften/SwiftenCompat.h" #include #if (SWIFTEN_VERSION >= 0x030000) #include #endif namespace Swift { ServerFromClientSession::ServerFromClientSession( const std::string& id, SWIFTEN_SHRPTR_NAMESPACE::shared_ptr connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, UserRegistry* userRegistry, XMLParserFactory* factory, Swift::JID remoteJID) : Session(connection, payloadParserFactories, payloadSerializers, factory), id_(id), userRegistry_(userRegistry), authenticated_(false), initialized(false), allowSASLEXTERNAL(false), tlsLayer(0), tlsConnected(false) { setRemoteJID(remoteJID); } ServerFromClientSession::~ServerFromClientSession() { if (tlsLayer) { delete tlsLayer; } } void ServerFromClientSession::handlePasswordValid() { if (!isInitialized()) { getXMPPLayer()->writeElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new AuthSuccess())); authenticated_ = true; getXMPPLayer()->resetParser(); } } void ServerFromClientSession::handlePasswordInvalid(const std::string &error) { if (!isInitialized()) { getXMPPLayer()->writeElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new AuthFailure)); if (!error.empty()) { SWIFTEN_SHRPTR_NAMESPACE::shared_ptr msg(new StreamError(StreamError::UndefinedCondition, error)); getXMPPLayer()->writeElement(msg); } finishSession(AuthenticationFailedError); } } #if (SWIFTEN_VERSION >= 0x030000) void ServerFromClientSession::handleElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr element) { #else void ServerFromClientSession::handleElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr element) { #endif if (isInitialized()) { onElementReceived(element); } else { if (AuthRequest* authRequest = dynamic_cast(element.get())) { if (authRequest->getMechanism() == "PLAIN" || (allowSASLEXTERNAL && authRequest->getMechanism() == "EXTERNAL")) { if (authRequest->getMechanism() == "EXTERNAL") { getXMPPLayer()->writeElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new AuthSuccess())); authenticated_ = true; getXMPPLayer()->resetParser(); } else { PLAINMessage plainMessage(authRequest->getMessage() ? *authRequest->getMessage() : createSafeByteArray("")); user_ = plainMessage.getAuthenticationID(); userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), getLocalJID().getDomain()), this, plainMessage.getPassword()); } } else { getXMPPLayer()->writeElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new AuthFailure)); finishSession(NoSupportedAuthMechanismsError); } } else if (dynamic_cast(element.get()) != NULL) { getXMPPLayer()->writeElement(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new TLSProceed)); getStreamStack()->addLayer(tlsLayer); tlsLayer->connect(); getXMPPLayer()->resetParser(); } else if (IQ* iq = dynamic_cast(element.get())) { if (SWIFTEN_SHRPTR_NAMESPACE::shared_ptr resourceBind = iq->getPayload()) { std::string bucket = "abcdefghijklmnopqrstuvwxyz"; std::string uuid; for (int i = 0; i < 10; i++) { uuid += bucket[rand() % bucket.size()]; } setRemoteJID(JID(user_, getLocalJID().getDomain(), uuid)); SWIFTEN_SHRPTR_NAMESPACE::shared_ptr resultResourceBind(new ResourceBind()); resultResourceBind->setJID(getRemoteJID()); getXMPPLayer()->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind)); } else if (iq->getPayload()) { getXMPPLayer()->writeElement(IQ::createResult(getRemoteJID(), iq->getID())); setInitialized(); } } } } void ServerFromClientSession::handleStreamStart(const ProtocolHeader& incomingHeader) { setLocalJID(JID(incomingHeader.getTo())); ProtocolHeader header; header.setFrom(incomingHeader.getTo()); header.setID(id_); getXMPPLayer()->writeHeader(header); SWIFTEN_SHRPTR_NAMESPACE::shared_ptr features(new StreamFeatures()); if (!authenticated_) { if (tlsLayer && !tlsConnected) { features->setHasStartTLS(); } features->addAuthenticationMechanism("PLAIN"); if (allowSASLEXTERNAL) { features->addAuthenticationMechanism("EXTERNAL"); } } else { features->setHasResourceBind(); features->setHasSession(); } getXMPPLayer()->writeElement(features); } void ServerFromClientSession::setInitialized() { initialized = true; onSessionStarted(); } void ServerFromClientSession::setAllowSASLEXTERNAL() { allowSASLEXTERNAL = true; } void ServerFromClientSession::handleSessionFinished(const boost::optional&) { userRegistry_->stopLogin(JID(user_, getLocalJID().getDomain()), this); } void ServerFromClientSession::addTLSEncryption(TLSServerContextFactory* tlsContextFactory, CertificateWithKey::ref cert) { tlsLayer = new TLSServerLayer(tlsContextFactory); if (!tlsLayer->setServerCertificate(cert)) { // std::cout << "error\n"; // TODO: // onClosed(SWIFTEN_SHRPTR_NAMESPACE::shared_ptr(new Error(Error::InvalidTLSCertificateError))); } else { tlsLayer->onError.connect(boost::bind(&ServerFromClientSession::handleTLSError, this)); tlsLayer->onConnected.connect(boost::bind(&ServerFromClientSession::handleTLSConnected, this)); // getStreamStack()->addLayer(tlsLayer); // tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, this)); // tlsLayer->onConnected.connect(boost::bind(&BasicSessionStream::handleTLSConnected, this)); // tlsLayer->connect(); } } }