/* * Copyright (c) 2010 Remko Tronçon * Licensed under the GNU General Public License v3. * See Documentation/Licenses/GPLv3.txt for more information. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace Swift { ServerFromClientSession::ServerFromClientSession( const std::string& id, boost::shared_ptr connection, PayloadParserFactoryCollection* payloadParserFactories, PayloadSerializerCollection* payloadSerializers, UserRegistry* userRegistry) : Session(connection, payloadParserFactories, payloadSerializers), id_(id), userRegistry_(userRegistry), authenticated_(false), initialized(false), allowSASLEXTERNAL(false), tlsLayer(0), tlsConnected(false) { } ServerFromClientSession::~ServerFromClientSession() { std::cout << "DESTRUCTOR;\n"; userRegistry_->onPasswordValid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordValid, this, _1)); userRegistry_->onPasswordInvalid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordInvalid, this, _1)); if (tlsLayer) { delete tlsLayer; } } void ServerFromClientSession::handlePasswordValid(const std::string &user) { if (user != JID(user_, getLocalJID().getDomain()).toString()) return; if (!isInitialized()) { userRegistry_->onPasswordValid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordValid, this, _1)); userRegistry_->onPasswordInvalid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordInvalid, this, _1)); getXMPPLayer()->writeElement(boost::shared_ptr(new AuthSuccess())); authenticated_ = true; getXMPPLayer()->resetParser(); } } void ServerFromClientSession::handlePasswordInvalid(const std::string &user) { if (user != JID(user_, getLocalJID().getDomain()).toString() || authenticated_) return; if (!isInitialized()) { userRegistry_->onPasswordValid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordValid, this, _1)); userRegistry_->onPasswordInvalid.disconnect(boost::bind(&ServerFromClientSession::handlePasswordInvalid, this, _1)); getXMPPLayer()->writeElement(boost::shared_ptr(new AuthFailure)); finishSession(AuthenticationFailedError); } } void ServerFromClientSession::handleElement(boost::shared_ptr element) { if (isInitialized()) { onElementReceived(element); } else { if (AuthRequest* authRequest = dynamic_cast(element.get())) { if (authRequest->getMechanism() == "PLAIN" || (allowSASLEXTERNAL && authRequest->getMechanism() == "EXTERNAL")) { if (authRequest->getMechanism() == "EXTERNAL") { getXMPPLayer()->writeElement(boost::shared_ptr(new AuthSuccess())); authenticated_ = true; getXMPPLayer()->resetParser(); } else { PLAINMessage plainMessage(authRequest->getMessage() ? *authRequest->getMessage() : createSafeByteArray("")); user_ = plainMessage.getAuthenticationID(); userRegistry_->onPasswordInvalid(JID(plainMessage.getAuthenticationID(), getLocalJID().getDomain()).toBare().toString()); userRegistry_->onPasswordValid.connect(boost::bind(&ServerFromClientSession::handlePasswordValid, this, _1)); userRegistry_->onPasswordInvalid.connect(boost::bind(&ServerFromClientSession::handlePasswordInvalid, this, _1)); if (userRegistry_->isValidUserPassword(JID(plainMessage.getAuthenticationID(), getLocalJID().getDomain()), plainMessage.getPassword())) { // we're waiting for usermanager signal now // authenticated_ = true; // getXMPPLayer()->resetParser(); } else { getXMPPLayer()->writeElement(boost::shared_ptr(new AuthFailure)); finishSession(AuthenticationFailedError); } } } else { getXMPPLayer()->writeElement(boost::shared_ptr(new AuthFailure)); finishSession(NoSupportedAuthMechanismsError); } } else if (dynamic_cast(element.get()) != NULL) { getXMPPLayer()->writeElement(boost::shared_ptr(new TLSProceed)); getStreamStack()->addLayer(tlsLayer); tlsLayer->connect(); getXMPPLayer()->resetParser(); } else if (IQ* iq = dynamic_cast(element.get())) { if (boost::shared_ptr resourceBind = iq->getPayload()) { std::string bucket = "abcdefghijklmnopqrstuvwxyz"; std::string uuid; for (int i = 0; i < 10; i++) { uuid += bucket[rand() % bucket.size()]; } setRemoteJID(JID(user_, getLocalJID().getDomain(), uuid)); boost::shared_ptr resultResourceBind(new ResourceBind()); resultResourceBind->setJID(getRemoteJID()); getXMPPLayer()->writeElement(IQ::createResult(JID(), iq->getID(), resultResourceBind)); } else if (iq->getPayload()) { getXMPPLayer()->writeElement(IQ::createResult(getRemoteJID(), iq->getID())); setInitialized(); } } } } void ServerFromClientSession::handleStreamStart(const ProtocolHeader& incomingHeader) { setLocalJID(JID("", incomingHeader.getTo())); ProtocolHeader header; header.setFrom(incomingHeader.getTo()); header.setID(id_); getXMPPLayer()->writeHeader(header); boost::shared_ptr features(new StreamFeatures()); if (!authenticated_) { if (tlsLayer && !tlsConnected) { features->setHasStartTLS(); } features->addAuthenticationMechanism("PLAIN"); if (allowSASLEXTERNAL) { features->addAuthenticationMechanism("EXTERNAL"); } } else { features->setHasResourceBind(); features->setHasSession(); } getXMPPLayer()->writeElement(features); } void ServerFromClientSession::setInitialized() { initialized = true; onSessionStarted(); } void ServerFromClientSession::setAllowSASLEXTERNAL() { allowSASLEXTERNAL = true; } void ServerFromClientSession::addTLSEncryption(TLSServerContextFactory* tlsContextFactory, const PKCS12Certificate& cert) { tlsLayer = new TLSServerLayer(tlsContextFactory); if (!tlsLayer->setServerCertificate(cert)) { // std::cout << "error\n"; // TODO: // onClosed(boost::shared_ptr(new Error(Error::InvalidTLSCertificateError))); } else { tlsLayer->onError.connect(boost::bind(&ServerFromClientSession::handleTLSError, this)); tlsLayer->onConnected.connect(boost::bind(&ServerFromClientSession::handleTLSConnected, this)); // getStreamStack()->addLayer(tlsLayer); // tlsLayer->onError.connect(boost::bind(&BasicSessionStream::handleTLSError, this)); // tlsLayer->onConnected.connect(boost::bind(&BasicSessionStream::handleTLSConnected, this)); // tlsLayer->connect(); } } }