package smb

import (
	"context"
	"fmt"
	"net"
	"strconv"
	"sync/atomic"

	"github.com/hirochachacha/go-smb2"
	"github.com/restic/restic/internal/debug"
)

// Parts of this code have been adapted from Rclone (https://github.com/rclone)
// Copyright (C) 2012 by Nick Craig-Wood http://www.craig-wood.com/nick/

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

// conn bundles the network connection to an SMB server, the SMB session
// established over it, and (optionally) one mounted share on that session.
type conn struct {
	// Transport connection handed to the SMB dialer in dial.
	// NOTE(review): this is a pointer to an interface value; a plain net.Conn
	// would be more idiomatic, but other code may depend on the field type.
	conn       *net.Conn
	// Session produced by the SMB handshake in dial.
	smbSession *smb2.Session
	// Currently mounted share, or nil when no share is mounted.
	smbShare   *smb2.Share
	// mountShare treats this as the name of the currently mounted share
	// ("" meaning none) and uses it to skip redundant remounts.
	shareName  string
}

// close tears the connection down: it unmounts the share (if any), logs off
// the SMB session and finally closes the underlying network connection.
//
// Logoff is attempted even when Umount fails. The Umount error takes
// precedence over the Logoff error in the return value. The raw transport is
// always closed afterwards, so a failed Logoff cannot leak the socket; a second
// Close on an already-closed transport only returns an error, which is ignored.
func (c *conn) close() (err error) {
	if c.smbShare != nil {
		err = c.smbShare.Umount()
	}
	logoffErr := c.smbSession.Logoff()

	// Release the socket regardless of how the SMB-level teardown went.
	if c.conn != nil && *c.conn != nil {
		_ = (*c.conn).Close()
	}

	if err != nil {
		return err
	}
	return logoffErr
}

// closed reports whether the connection is no longer usable.
//
// It issues a cheap request to probe liveness: a stat of the share's current
// directory when a share is mounted, otherwise a listing of the server's share
// names. It returns true when that probe fails and false when it succeeds.
func (c *conn) closed() bool {
	var probeErr error
	if c.smbShare != nil {
		_, probeErr = c.smbShare.Stat(".")
	} else {
		_, probeErr = c.smbSession.ListSharenames()
	}
	// The previous version returned probeErr == nil, i.e. "true if alive",
	// which made drainPool skip (and leak) every healthy connection.
	return probeErr != nil
}

// addSession increments the count of SMB sessions currently in use.
//
// Every call must be balanced by a later call to removeSession.
func (b *Backend) addSession() {
	counter := &b.sessions
	atomic.AddInt32(counter, 1)
}

// Show the SMB session is no longer in use
func (b *Backend) removeSession() {
	atomic.AddInt32(&b.sessions, -1)
}

// getSessions atomically reads and returns the number of SMB sessions
// currently marked as in use; zero means none are active.
func (b *Backend) getSessions() int32 {
	active := atomic.LoadInt32(&b.sessions)
	return active
}

// dial opens a network connection to addr and performs the SMB negotiate and
// NTLM session-setup handshake over it, returning a conn that holds the
// resulting session. No share is mounted.
//
// ctx bounds both the network dial and the SMB handshake. If the handshake
// fails, the freshly opened network connection is closed before returning so
// the socket is not leaked.
func (b *Backend) dial(ctx context.Context, network, addr string) (*conn, error) {
	var dialer net.Dialer
	tconn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	// The raw bytes of ClientGUID are copied, truncated or zero-padded to 16
	// bytes; an empty ClientGUID leaves the all-zero ID.
	// NOTE(review): a textual GUID is not parsed here — confirm this is intended.
	var clientID [16]byte
	copy(clientID[:], b.ClientGUID)

	d := &smb2.Dialer{
		Negotiator: smb2.Negotiator{
			RequireMessageSigning: b.RequireMessageSigning,
			SpecifiedDialect:      b.Dialect,
			ClientGuid:            clientID,
		},
		Initiator: &smb2.NTLMInitiator{
			User:     b.User,
			Password: b.Password.Unwrap(),
			Domain:   b.Domain,
		},
	}

	session, err := d.DialContext(ctx, tconn)
	if err != nil {
		// The handshake failed; nothing else owns tconn, so release it here.
		_ = tconn.Close()
		return nil, err
	}

	return &conn{
		smbSession: session,
		conn:       &tconn,
	}, nil
}

// newConnection opens a fresh connection to b.Host:b.Port over TCP and, when
// share is non-empty, mounts that share on it.
//
// The connection is deliberately decoupled from any caller context because it
// may be pooled and reused later. On failure no connection is returned and any
// partially set-up connection has been closed.
func (b *Backend) newConnection(share string) (c *conn, err error) {
	ctx := context.Background()

	c, err = b.dial(ctx, "tcp", b.Host+":"+strconv.Itoa(b.Port))
	if err != nil {
		return nil, fmt.Errorf("couldn't connect SMB: %w", err)
	}

	if share != "" {
		mounted, mountErr := c.smbSession.Mount(share)
		if mountErr != nil {
			// No share is attached yet, so this only logs off and frees the socket.
			_ = c.close()
			return nil, fmt.Errorf("couldn't initialize SMB: %w", mountErr)
		}
		c.smbShare = mounted.WithContext(ctx)
	}
	// Record what is mounted so mountShare recognizes it when the connection
	// is later taken from the pool; without this it would remount needlessly
	// or, for share == "", wrongly keep the share mounted.
	c.shareName = share

	return c, nil
}

// mountShare makes sure that exactly share is mounted on c. An empty share
// means "no share mounted". It is a no-op when the requested share is already
// mounted. Otherwise any currently mounted share is unmounted first, and then
// the new one (if any) is mounted.
//
// On error, c is left with no share mounted, and c.shareName is "" to match,
// so the connection never claims a share it does not hold.
func (c *conn) mountShare(share string) error {
	// Trust shareName only when it agrees with whether a share handle exists.
	if c.shareName == share && (c.smbShare != nil) == (share != "") {
		return nil
	}
	if c.smbShare != nil {
		err := c.smbShare.Umount()
		// The old handle is unusable from here on, whether or not Umount succeeded.
		c.smbShare = nil
		c.shareName = ""
		if err != nil {
			return err
		}
	}
	if share != "" {
		mounted, err := c.smbSession.Mount(share)
		if err != nil {
			c.shareName = ""
			return err
		}
		c.smbShare = mounted
	}
	c.shareName = share
	return nil
}

// getConnection returns a connection with share mounted (or with no share
// mounted when share is empty).
//
// Pooled connections are tried first, oldest first. A pooled connection that
// cannot be switched to share is closed and discarded. When the pool is
// exhausted, a new connection is opened via newConnection.
//
// ctx is currently unused: pooled connections are intentionally not tied to the
// caller's context.
func (b *Backend) getConnection(ctx context.Context, share string) (c *conn, err error) {
	for {
		b.poolMu.Lock()
		if len(b.pool) == 0 {
			b.poolMu.Unlock()
			break
		}
		c = b.pool[0]
		b.pool[0] = nil // don't keep the popped conn reachable via the backing array
		b.pool = b.pool[1:]
		b.poolMu.Unlock()

		// c is now exclusively ours, so the (network-bound) share switch runs
		// without holding poolMu.
		err = c.mountShare(share)
		if err == nil {
			return c, nil
		}
		debug.Log("Discarding unusable SMB connection: %v", err)
		// Best effort: the connection is already known to be unusable, but its
		// session and socket must still be released rather than leaked.
		_ = c.close()
	}
	return b.newConnection(share)
}

// putConnection hands c back to the pool for later reuse.
//
// c is first probed with a cheap request: a stat of the share's current
// directory when a share is mounted, otherwise a listing of the share names.
// If the probe fails, c is closed instead of being pooled.
func (b *Backend) putConnection(c *conn) {
	var nopErr error
	if c.smbShare != nil {
		_, nopErr = c.smbShare.Stat(".")
	} else {
		_, nopErr = c.smbSession.ListSharenames()
	}
	if nopErr != nil {
		debug.Log("Connection failed, closing: %v", nopErr)
		_ = c.close()
		return
	}

	b.poolMu.Lock()
	b.pool = append(b.pool, c)
	// Nudge the pool-emptying timer only when an idle timeout is configured,
	// mirroring the guard drainPool uses before b.drain.Stop(). Presumably
	// b.drain is nil otherwise, and Reset(0) would fire the drain immediately.
	if b.Config.IdleTimeout > 0 {
		b.drain.Reset(b.Config.IdleTimeout)
	}
	b.poolMu.Unlock()
}

// drainPool closes every pooled connection and empties the pool.
//
// While any sessions are still active (see addSession/removeSession), nothing
// is closed and the idle timer is re-armed instead. Connections that still
// respond are logged off cleanly via close. The underlying socket of every
// pooled connection is released either way. The last error returned by close,
// if any, is returned.
func (b *Backend) drainPool() (err error) {
	b.poolMu.Lock()
	defer b.poolMu.Unlock()
	if sessions := b.getSessions(); sessions != 0 {
		debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions)
		// Same guard as the Stop below: the timer is only used with an idle timeout.
		if b.Config.IdleTimeout > 0 {
			b.drain.Reset(b.Config.IdleTimeout) // nudge on the pool emptying timer
		}
		return nil
	}
	if b.Config.IdleTimeout > 0 {
		b.drain.Stop()
	}
	if len(b.pool) != 0 {
		debug.Log("Closing %d unused connections", len(b.pool))
	}
	for i, c := range b.pool {
		if !c.closed() {
			if cErr := c.close(); cErr != nil {
				err = cErr
			}
		}
		// Connections skipped above would otherwise keep their socket open;
		// closing an already-closed transport is harmless.
		if c.conn != nil && *c.conn != nil {
			_ = (*c.conn).Close()
		}
		b.pool[i] = nil
	}
	b.pool = nil
	return err
}