package sftp

import (
	"encoding"
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

type _testSender struct {
	sent chan encoding.BinaryMarshaler
}

func newTestSender() *_testSender {
	return &_testSender{make(chan encoding.BinaryMarshaler)}
}

func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
	s.sent <- p
	return nil
}

type fakepacket uint32

func (fakepacket) MarshalBinary() ([]byte, error) {
	return []byte{}, nil
}

func (fakepacket) UnmarshalBinary([]byte) error {
	return nil
}

func (f fakepacket) id() uint32 {
	return uint32(f)
}

type pair struct {
	in  fakepacket
	out fakepacket
}

// basic test
var ttable1 = []pair{
	pair{fakepacket(0), fakepacket(0)},
	pair{fakepacket(1), fakepacket(1)},
	pair{fakepacket(2), fakepacket(2)},
	pair{fakepacket(3), fakepacket(3)},
}

// outgoing packets out of order
var ttable2 = []pair{
	pair{fakepacket(0), fakepacket(0)},
	pair{fakepacket(1), fakepacket(4)},
	pair{fakepacket(2), fakepacket(1)},
	pair{fakepacket(3), fakepacket(3)},
	pair{fakepacket(4), fakepacket(2)},
}

// incoming packets out of order
var ttable3 = []pair{
	pair{fakepacket(2), fakepacket(0)},
	pair{fakepacket(1), fakepacket(1)},
	pair{fakepacket(3), fakepacket(2)},
	pair{fakepacket(0), fakepacket(3)},
}

var tables = [][]pair{ttable1, ttable2, ttable3}

func TestPacketManager(t *testing.T) {
	sender := newTestSender()
	s := newPktMgr(sender)

	for i := range tables {
		table := tables[i]
		for _, p := range table {
			s.incomingPacket(p.in)
		}
		for _, p := range table {
			s.readyPacket(p.out)
		}
		for i := 0; i < len(table); i++ {
			pkt := <-sender.sent
			id := pkt.(fakepacket).id()
			assert.Equal(t, id, uint32(i))
		}
	}
	s.close()
}

func (p sshFxpRemovePacket) String() string {
	return fmt.Sprintf("RmPct:%d", p.ID)
}
func (p sshFxpOpenPacket) String() string {
	return fmt.Sprintf("OpPct:%d", p.ID)
}
func (p sshFxpWritePacket) String() string {
	return fmt.Sprintf("WrPct:%d", p.ID)
}
func (p sshFxpClosePacket) String() string {
	return fmt.Sprintf("ClPct:%d", p.ID)
}

// Test what happens when the pool processes a close packet on a file that it
// is still reading from.
func TestCloseOutOfOrder(t *testing.T) {
	packets := []requestPacket{
		&sshFxpRemovePacket{ID: 0, Filename: "foo"},
		&sshFxpOpenPacket{ID: 1},
		&sshFxpWritePacket{ID: 2, Handle: "foo"},
		&sshFxpWritePacket{ID: 3, Handle: "foo"},
		&sshFxpWritePacket{ID: 4, Handle: "foo"},
		&sshFxpWritePacket{ID: 5, Handle: "foo"},
		&sshFxpClosePacket{ID: 6, Handle: "foo"},
		&sshFxpRemovePacket{ID: 7, Filename: "foo"},
	}

	recvChan := make(chan requestPacket, len(packets)+1)
	sender := newTestSender()
	pktMgr := newPktMgr(sender)
	wg := sync.WaitGroup{}
	wg.Add(len(packets))
	runWorker := func(ch requestChan) {
		go func() {
			for pkt := range ch {
				if _, ok := pkt.(*sshFxpWritePacket); ok {
					// sleep to cause writes to come after close/remove
					time.Sleep(time.Millisecond)
				}
				pktMgr.working.Done()
				recvChan <- pkt
				wg.Done()
			}
		}()
	}
	pktChan := pktMgr.workerChan(runWorker)
	for _, p := range packets {
		pktChan <- p
	}
	wg.Wait()
	close(recvChan)
	received := []requestPacket{}
	for p := range recvChan {
		received = append(received, p)
	}
	if received[len(received)-2].id() != packets[len(packets)-2].id() {
		t.Fatal("Packets processed out of order1:", received, packets)
	}
	if received[len(received)-1].id() != packets[len(packets)-1].id() {
		t.Fatal("Packets processed out of order2:", received, packets)
	}
}