diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9c9555543..e515d2781 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,9 +1,10 @@ name: test on: # run tests on push to master, but not when other branches are pushed to + workflow_dispatch: {} push: branches: - - master + - 'feature/smb-test-setup' # run tests for all pull requests pull_request: @@ -26,6 +27,7 @@ jobs: go: 1.19.x os: macOS-latest test_fuse: false + test_smb: false - job_name: Linux go: 1.19.x @@ -82,6 +84,65 @@ jobs: chmod 755 $HOME/bin/rclone rm -rf rclone* + echo "install samba" + user="smbuser" + pass="mGoWwqvgdnwtmh07" + + if [ "$RUNNER_OS" == "macOS" ]; then + #NONINTERACTIVE=1 brew install samba + else + sudo apt-get update + sudo apt-get install samba -y + + echo "Allow Samba in firewall" + sudo ufw allow 'Samba' + + echo "modifying samba config" + echo '' | sudo tee -a /etc/samba/smb.conf + echo ' interfaces = 127.0.0.0/8 eth0' | sudo tee -a /etc/samba/smb.conf + echo ' bind interfaces only = yes' | sudo tee -a /etc/samba/smb.conf + echo '' | sudo tee -a /etc/samba/smb.conf + echo "[$user]" | sudo tee -a /etc/samba/smb.conf + echo ' comment = Samba on Ubuntu' | sudo tee -a /etc/samba/smb.conf + echo " path = /samba/$user" | sudo tee -a /etc/samba/smb.conf + echo ' browseable = yes' | sudo tee -a /etc/samba/smb.conf + echo ' read only = no' | sudo tee -a /etc/samba/smb.conf + echo ' force create mode = 0660' | sudo tee -a /etc/samba/smb.conf + echo ' force directory mode = 2770' | sudo tee -a /etc/samba/smb.conf + echo " valid users = $user" | sudo tee -a /etc/samba/smb.conf + + echo "restart services" + sudo systemctl restart smbd + sudo systemctl restart nmbd + + echo "create samba share directory" + sudo mkdir /samba + + echo "change sambashare group" + sudo chgrp sambashare /samba + + echo "add samba user" + sudo id -u "$user" &>/dev/null || sudo useradd -M -d "/samba/$user" -s /usr/sbin/nologin -G sambashare "$user" + + echo "create samba share user directory" + sudo mkdir "/samba/$user" + + echo "change samba share user directory ownership" + sudo chown "$user":sambashare "/samba/$user" + + echo "modify permissions on samba share user directory" + sudo chmod 2770 "/samba/$user" + + echo "change smb password" + (echo "$pass"; echo "$pass") | sudo smbpasswd -a "$user" + + echo "enable samba user" + sudo smbpasswd -e "$user" + + echo "restart services" + sudo systemctl restart smbd + sudo systemctl restart nmbd + fi # add $HOME/bin to path ($GOBIN was already added to the path by setup-go@v3) echo $HOME/bin >> $GITHUB_PATH if: matrix.os == 'ubuntu-latest' || matrix.os == 'macOS-latest' @@ -121,6 +182,26 @@ jobs: Invoke-WebRequest https://github.com/restic/test-assets/raw/master/libiconv-1.8-1-bin.zip -OutFile libiconv.zip unzip libiconv.zip + # Create new smbshare + $user="smbuser" + $pass="mGoWwqvgdnwtmh07" + $SecurePassword = $pass | ConvertTo-SecureString -AsPlainText -Force + + echo "Create user" + New-LocalUser $user -Password $SecurePassword -FullName "SMB User" -Description "Account used for smb access." + + echo "Making user admin" + Add-LocalGroupMember -Group "Administrators" -Member "$user" + + $path="C:\$user" + mkdir $path + + echo "Create share" + New-SmbShare -Name $user -Path $path -FullAccess "Administrators" -EncryptData $True + + echo "Grant access to share" + Grant-SmbShareAccess -Name $user -AccountName $user -AccessRight Full -Force + # add $USERPROFILE/tar/bin to path echo $Env:USERPROFILE\tar\bin >> $Env:GITHUB_PATH if: matrix.os == 'windows-latest' @@ -135,6 +216,7 @@ jobs: - name: Run local Tests env: RESTIC_TEST_FUSE: ${{ matrix.test_fuse }} + RESTIC_TEST_SMB: ${{ matrix.test_smb }} run: | go test -cover ${{matrix.test_opts}} ./... diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 8a4edf407..5b44041f2 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -25,6 +25,7 @@ import ( "github.com/restic/restic/internal/backend/retry" "github.com/restic/restic/internal/backend/s3" "github.com/restic/restic/internal/backend/sftp" + "github.com/restic/restic/internal/backend/smb" "github.com/restic/restic/internal/backend/swift" "github.com/restic/restic/internal/cache" "github.com/restic/restic/internal/debug" @@ -683,6 +684,80 @@ func parseConfig(loc location.Location, opts options.Options) (interface{}, erro debug.Log("opening rest repository at %#v", cfg) return cfg, nil + case "smb": + cfg := loc.Config.(smb.Config) + if err := opts.Apply(loc.Scheme, &cfg); err != nil { + return nil, err + } + if cfg.User == "" { + cfg.User = os.Getenv("RESTIC_SMB_USER") + } + + if cfg.Password.String() == "" { + cfg.Password = options.NewSecretString(os.Getenv("RESTIC_SMB_PASSWORD")) + } + + if cfg.Domain == "" { + cfg.Domain = os.Getenv("RESTIC_SMB_DOMAIN") + } + if cfg.Domain == "" { + cfg.Domain = smb.DefaultDomain + } + + //0 is an acceptable value for timeout, hence using -1 as the default unset value. + if cfg.IdleTimeout == nil { + it := os.Getenv("RESTIC_SMB_IDLETIMEOUTSECS") + if it == "" { + timeout := smb.DefaultIdleTimeout + cfg.IdleTimeout = &timeout + } else { + t, err := strconv.Atoi(it) + if err != nil { + return nil, err + } + timeout := (time.Duration(int64(t) * int64(time.Second))) + cfg.IdleTimeout = &timeout + } + } + + if cfg.Connections == 0 { + c := os.Getenv("RESTIC_SMB_CONNECTIONS") + if c == "" { + cfg.Connections = smb.DefaultConnections + } else { + con, err := strconv.Atoi(c) + if err != nil { + return nil, err + } + cfg.Connections = uint(con) + } + } + + if cfg.RequireMessageSigning == nil { + v := os.Getenv("RESTIC_SMB_REQUIRE_MESSAGESIGNING") + rms := strings.ToLower(v) == "true" + cfg.RequireMessageSigning = &rms + } + + if cfg.ClientGuid == "" { + c := os.Getenv("RESTIC_SMB_CLIENTGUID") + cfg.ClientGuid = c + } + + if cfg.Dialect == 0 { + d := os.Getenv("RESTIC_SMB_DIALECT") + if d != "" { + v, err := strconv.Atoi(d) + if err != nil { + return nil, err + } + cfg.Dialect = uint16(v) + } + } + + debug.Log("opening smb repository at %#v", cfg) + return cfg, nil + } return nil, errors.Fatalf("invalid backend: %q", loc.Scheme) @@ -717,6 +792,8 @@ func open(ctx context.Context, s string, gopts GlobalOptions, opts options.Optio be, err = local.Open(ctx, cfg.(local.Config)) case "sftp": be, err = sftp.Open(ctx, cfg.(sftp.Config)) + case "smb": + be, err = smb.Open(ctx, cfg.(smb.Config)) case "s3": be, err = s3.Open(ctx, cfg.(s3.Config), rt) case "gs": @@ -748,7 +825,7 @@ func open(ctx context.Context, s string, gopts GlobalOptions, opts options.Optio } } - if loc.Scheme == "local" || loc.Scheme == "sftp" { + if loc.Scheme == "local" || loc.Scheme == "sftp" || loc.Scheme == "smb" { // wrap the backend in a LimitBackend so that the throughput is limited be = limiter.LimitBackend(be, lim) } @@ -789,6 +866,8 @@ func create(ctx context.Context, s string, opts options.Options) (restic.Backend return local.Create(ctx, cfg.(local.Config)) case "sftp": return sftp.Create(ctx, cfg.(sftp.Config)) + case "smb": + return smb.Create(ctx, cfg.(smb.Config)) case "s3": return s3.Create(ctx, cfg.(s3.Config), rt) case "gs": diff --git a/go.mod b/go.mod index a172c6992..f6fae1dcc 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/go-ole/go-ole v1.2.6 github.com/google/go-cmp v0.5.9 github.com/hashicorp/golang-lru/v2 v2.0.1 + github.com/hirochachacha/go-smb2 v1.1.0 github.com/juju/ratelimit v1.0.2 github.com/klauspost/compress v1.15.15 github.com/kurin/blazer v0.5.4-0.20230113224640-3887e1ec64b5 @@ -24,6 +25,7 @@ require ( github.com/restic/chunker v0.4.0 github.com/spf13/cobra v1.6.1 github.com/spf13/pflag v1.0.5 + github.com/valyala/fastrand v1.1.0 golang.org/x/crypto v0.5.0 golang.org/x/net v0.5.0 golang.org/x/oauth2 v0.4.0 @@ -44,6 +46,7 @@ require ( github.com/dnaeon/go-vcr v1.2.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/felixge/fgprof v0.9.3 // indirect + github.com/geoffgarside/ber v1.1.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/pprof v0.0.0-20230111200839-76d1ae5aea2b // indirect diff --git a/go.sum b/go.sum index 08069a411..1a12f025c 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= +github.com/geoffgarside/ber v1.1.0 h1:qTmFG4jJbwiSzSXoNJeHcOprVzZ8Ulde2Rrrifu5U9w= +github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= @@ -95,6 +97,8 @@ github.com/googleapis/gax-go/v2 v2.7.0 h1:IcsPKeInNvYi7eqSaDjiZqDDKu5rsmunY0Y1Yu github.com/googleapis/gax-go/v2 v2.7.0/go.mod h1:TEop28CZZQ2y+c0VxMUmu1lV+fQx57QpBWsYpwqHJx8= github.com/hashicorp/golang-lru/v2 v2.0.1 h1:5pv5N1lT1fjLg2VQ5KWc7kmucp2x/kvFOnxuVTqZ6x4= github.com/hashicorp/golang-lru/v2 v2.0.1/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hirochachacha/go-smb2 v1.1.0 h1:b6hs9qKIql9eVXAiN0M2wSFY5xnhbHAQoCwRKbaRTZI= +github.com/hirochachacha/go-smb2 v1.1.0/go.mod h1:8F1A4d5EZzrGu5R7PU163UcMRDJQl4FtcxjBfsY8TZE= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -165,12 +169,15 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c h1:u6SKchux2yDvFQnDHS3lPnIRmfVJ5Sxy3ao2SIdysLQ= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= +github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8= +github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= diff --git a/internal/backend/location/location.go b/internal/backend/location/location.go index a732233cc..c9bbf981f 100644 --- a/internal/backend/location/location.go +++ b/internal/backend/location/location.go @@ -12,6 +12,7 @@ import ( "github.com/restic/restic/internal/backend/rest" "github.com/restic/restic/internal/backend/s3" "github.com/restic/restic/internal/backend/sftp" + "github.com/restic/restic/internal/backend/smb" "github.com/restic/restic/internal/backend/swift" "github.com/restic/restic/internal/errors" ) @@ -41,6 +42,7 @@ var parsers = []parser{ {"swift", swift.ParseConfig, noPassword}, {"rest", rest.ParseConfig, rest.StripPassword}, {"rclone", rclone.ParseConfig, noPassword}, + {"smb", smb.ParseConfig, noPassword}, } // noPassword returns the repository location unchanged (there's no sensitive information there) diff --git a/internal/backend/smb/config.go b/internal/backend/smb/config.go new file mode 100644 index 000000000..9d63d3b96 --- /dev/null +++ b/internal/backend/smb/config.go @@ -0,0 +1,96 @@ +package smb + +import ( + "path" + "strconv" + "strings" + "time" + + "github.com/restic/restic/internal/errors" + "github.com/restic/restic/internal/options" +) + +// Config contains all configuration necessary to connect to an SMB server +type Config struct { + Address string + Port int + ShareName string + Path string + + Layout string `option:"layout" help:"use this backend directory layout (default: auto-detect)"` + Connections uint `option:"connections" help:"set a limit for the number of concurrent operations (default: 2)"` + IdleTimeout *time.Duration `option:"idle-timeout" help:"Max time in seconds before closing idle connections. If no connections have been returned to the connection pool in the time given, the connection pool will be emptied. Set to 0 to keep connections indefinitely.(default: 60)"` + RequireMessageSigning *bool `option:"require-message-signing" help:"Mandates message signing otherwise does not allow the connection. If this is false, messaging signing is just enabled and not enforced. (default: false)"` + Dialect uint16 `option:"dialect" help:"Force a specific dialect to be used. SMB311:785, SMB302:770, SMB300:768, SMB210:528, SMB202:514, SMB2:767. If unspecfied (0), following dialects are tried in order - SMB311, SMB302, SMB300, SMB210, SMB202 (default: 0)"` + ClientGuid string `option:"client-guid" help:"A 16-byte GUID to uniquely identify a client. If not specific a random GUID is used. (default: \"\")"` + + User string `option:"user"` + Password options.SecretString `option:"password"` + Domain string `option:"domain"` +} + +const ( + DefaultSmbPort int = 445 + DefaultDomain string = "WORKGROUP" + DefaultConnections uint = 2 + DefaultIdleTimeout time.Duration = 60 * time.Second +) + +// NewConfig returns a new Config with the default values filled in. +func NewConfig() Config { + return Config{ + Port: DefaultSmbPort, + } +} + +func init() { + options.Register("smb", Config{}) +} + +// ParseConfig parses the string s and extracts the s3 config. The two +// supported configuration formats are smb://address:port/sharename/directory and +// smb://address/sharename/directory in which case default port 445 is used. +// If no prefix is given the prefix "restic" will be used. +func ParseConfig(s string) (interface{}, error) { + switch { + case strings.HasPrefix(s, "smb://"): + s = s[6:] + case strings.HasPrefix(s, "smb:"): + s = s[4:] + default: + return nil, errors.New("smb: invalid format") + } + // use the first entry of the path as the endpoint and the + // remainder as bucket name and prefix + fullAddress, rest, _ := strings.Cut(s, "/") + address, portString, hasPort := strings.Cut(fullAddress, ":") + var port int + if !hasPort { + port = DefaultSmbPort + } else { + var err error + port, err = strconv.Atoi(portString) + if err != nil { + return nil, err + } + } + sharename, directory, _ := strings.Cut(rest, "/") + return createConfig(address, port, sharename, directory) +} + +func createConfig(address string, port int, sharename string, directory string) (interface{}, error) { + if address == "" { + return nil, errors.New("smb: invalid format, address not found") + } + + if directory != "" { + directory = path.Clean(directory) + } + + cfg := NewConfig() + cfg.Address = address + cfg.Port = port + cfg.ShareName = sharename + cfg.Path = directory + return cfg, nil +} diff --git a/internal/backend/smb/config_test.go b/internal/backend/smb/config_test.go new file mode 100644 index 000000000..678bb8db0 --- /dev/null +++ b/internal/backend/smb/config_test.go @@ -0,0 +1,51 @@ +package smb + +import ( + "strings" + "testing" +) + +var configTests = []struct { + s string + cfg Config +}{ + {"smb://shareaddress/sharename/directory", Config{ + Address: "shareaddress", + Port: DefaultSmbPort, + ShareName: "sharename", + Path: "directory", + }}, + {"smb://shareaddress:456/sharename/directory", Config{ + Address: "shareaddress", + Port: 456, + ShareName: "sharename", + Path: "directory", + }}, +} + +func TestParseConfig(t *testing.T) { + for i, test := range configTests { + cfg, err := ParseConfig(test.s) + if err != nil { + t.Errorf("test %d:%s failed: %v", i, test.s, err) + continue + } + + if cfg != test.cfg { + t.Errorf("test %d:\ninput:\n %s\n wrong config, want:\n %v\ngot:\n %v", + i, test.s, test.cfg, cfg) + continue + } + } +} + +func TestParseError(t *testing.T) { + const prefix = "smb: invalid format," + + for _, s := range []string{"", "/", "//", "/sharename/directory"} { + _, err := ParseConfig("smb://" + s) + if err == nil || !strings.HasPrefix(err.Error(), prefix) { + t.Errorf("expected %q, got %q", prefix, err) + } + } +} diff --git a/internal/backend/smb/conpool.go b/internal/backend/smb/conpool.go new file mode 100644 index 000000000..9855d0557 --- /dev/null +++ b/internal/backend/smb/conpool.go @@ -0,0 +1,230 @@ +package smb + +import ( + "context" + "fmt" + "net" + "strconv" + "sync/atomic" + + "github.com/hirochachacha/go-smb2" + "github.com/restic/restic/internal/debug" +) + +// conn encapsulates a SMB client and corresponding SMB client +type conn struct { + conn *net.Conn + smbSession *smb2.Session + smbShare *smb2.Share + shareName string +} + +// Closes the connection +func (c *conn) close() (err error) { + if c.smbShare != nil { + err = c.smbShare.Umount() + } + sessionLogoffErr := c.smbSession.Logoff() + if err != nil { + return err + } + return sessionLogoffErr +} + +// True if it's closed +func (c *conn) closed() bool { + var nopErr error + if c.smbShare != nil { + // stat the current directory + _, nopErr = c.smbShare.Stat(".") + } else { + // list the shares + _, nopErr = c.smbSession.ListSharenames() + } + return nopErr == nil +} + +// Show that we are using a SMB session +// +// Call removeSession() when done +func (b *Backend) addSession() { + atomic.AddInt32(&b.sessions, 1) +} + +// Show the SMB session is no longer in use +func (b *Backend) removeSession() { + atomic.AddInt32(&b.sessions, -1) +} + +// getSessions shows whether there are any sessions in use +func (b *Backend) getSessions() int32 { + return atomic.LoadInt32(&b.sessions) +} + +// dial starts a client connection to the given SMB server. It is a +// convenience function that connects to the given network address, +// initiates the SMB handshake, and then sets up a Client. +func (b *Backend) dial(ctx context.Context, network, addr string) (*conn, error) { + dialer := net.Dialer{} + tconn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + var clientId [16]byte + if b.ClientGuid != "" { + copy(clientId[:], []byte(b.ClientGuid)) + } + + rms := b.RequireMessageSigning != nil + if rms { + rms = *b.RequireMessageSigning + } + d := &smb2.Dialer{ + Negotiator: smb2.Negotiator{ + RequireMessageSigning: rms, + SpecifiedDialect: b.Dialect, + ClientGuid: clientId, + }, + Initiator: &smb2.NTLMInitiator{ + User: b.User, + Password: b.Password.Unwrap(), + Domain: b.Domain, + }, + } + + session, err := d.DialContext(ctx, tconn) + if err != nil { + return nil, err + } + + return &conn{ + smbSession: session, + conn: &tconn, + }, nil +} + +// Open a new connection to the SMB server. +func (b *Backend) newConnection(share string) (c *conn, err error) { + // As we are pooling these connections we need to decouple + // them from the current context + ctx := context.Background() + + c, err = b.dial(ctx, "tcp", b.Address+":"+strconv.Itoa(b.Port)) + if err != nil { + return nil, fmt.Errorf("couldn't connect SMB: %w", err) + } + + if share != "" { + // mount the specified share as well if user requested + c.smbShare, err = c.smbSession.Mount(share) + if err != nil { + _ = c.smbSession.Logoff() + return nil, fmt.Errorf("couldn't initialize SMB: %w", err) + } + c.smbShare = c.smbShare.WithContext(ctx) + } + + return c, nil +} + +// Ensure the specified share is mounted or the session is unmounted +func (c *conn) mountShare(share string) (err error) { + if c.shareName == share { + return nil + } + if c.smbShare != nil { + err = c.smbShare.Umount() + c.smbShare = nil + } + if err != nil { + return + } + if share != "" { + c.smbShare, err = c.smbSession.Mount(share) + if err != nil { + return + } + } + c.shareName = share + return nil +} + +// Get a SMB connection from the pool, or open a new one +func (b *Backend) getConnection(ctx context.Context, share string) (c *conn, err error) { + b.poolMu.Lock() + for len(b.pool) > 0 { + c = b.pool[0] + b.pool = b.pool[1:] + err = c.mountShare(share) + if err == nil { + break + } + debug.Log("Discarding unusable SMB connection: %v", err) + c = nil + } + b.poolMu.Unlock() + if c != nil { + return c, nil + } + c, err = b.newConnection(share) + return c, err +} + +// Return a SMB connection to the pool +// +// It nils the pointed to connection out so it can't be reused +func (b *Backend) putConnection(pc **conn) { + c := *pc + *pc = nil + + var nopErr error + if c.smbShare != nil { + // stat the current directory + _, nopErr = c.smbShare.Stat(".") + } else { + // list the shares + _, nopErr = c.smbSession.ListSharenames() + } + if nopErr != nil { + debug.Log("Connection failed, closing: %v", nopErr) + _ = c.close() + return + } + + b.poolMu.Lock() + b.pool = append(b.pool, c) + if b.Config.IdleTimeout != nil && *b.Config.IdleTimeout > 0 { + b.drain.Reset(*b.Config.IdleTimeout) // nudge on the pool emptying timer + } + b.poolMu.Unlock() +} + +// Drain the pool of any connections +func (b *Backend) drainPool() (err error) { + b.poolMu.Lock() + defer b.poolMu.Unlock() + if sessions := b.getSessions(); sessions != 0 { + debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions) + if b.Config.IdleTimeout != nil && *b.Config.IdleTimeout > 0 { + b.drain.Reset(*b.Config.IdleTimeout) // nudge on the pool emptying timer + } + return nil + } + if b.Config.IdleTimeout != nil && *b.Config.IdleTimeout > 0 { + b.drain.Stop() + } + if len(b.pool) != 0 { + debug.Log("Closing %d unused connections", len(b.pool)) + } + for i, c := range b.pool { + if !c.closed() { + cErr := c.close() + if cErr != nil { + err = cErr + } + } + b.pool[i] = nil + } + b.pool = nil + return err +} diff --git a/internal/backend/smb/smb.go b/internal/backend/smb/smb.go new file mode 100644 index 000000000..d1dad36de --- /dev/null +++ b/internal/backend/smb/smb.go @@ -0,0 +1,565 @@ +package smb + +import ( + "context" + "hash" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "strconv" + "sync" + "syscall" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/hirochachacha/go-smb2" + "github.com/restic/restic/internal/backend" + "github.com/restic/restic/internal/backend/layout" + "github.com/restic/restic/internal/backend/sema" + "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/errors" + "github.com/restic/restic/internal/restic" + + "github.com/valyala/fastrand" +) + +// Backend stores data on an SMB endpoint. +type Backend struct { + sem sema.Semaphore + Config + layout.Layout + backend.Modes + + sessions int32 + poolMu sync.Mutex + pool []*conn + drain *time.Timer // used to drain the pool when we stop using the connections +} + +// make sure that *Backend implements backend.Backend +var _ restic.Backend = &Backend{} + +const ( + defaultLayout = "default" +) + +func open(ctx context.Context, cfg Config) (*Backend, error) { + + l, err := layout.ParseLayout(ctx, &layout.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) + if err != nil { + return nil, err + } + + sem, err := sema.New(cfg.Connections) + if err != nil { + return nil, err + } + + b := &Backend{ + Config: cfg, + sem: sem, + Layout: l, + } + + debug.Log("open, config %#v", cfg) + + // set the pool drainer timer going + if b.Config.IdleTimeout != nil && *b.Config.IdleTimeout > 0 { + b.drain = time.AfterFunc(*b.Config.IdleTimeout, func() { _ = b.drainPool() }) + } + + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return nil, err + } + defer b.putConnection(&cn) + + stat, err := cn.smbShare.Stat(l.Filename(restic.Handle{Type: restic.ConfigFile})) + m := backend.DeriveModesFromFileInfo(stat, err) + debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir) + + b.Modes = m + + return b, nil +} + +// Open opens the local backend as specified by config. +func Open(ctx context.Context, cfg Config) (*Backend, error) { + debug.Log("open local backend at %v (layout %q)", cfg.Path, cfg.Layout) + return open(ctx, cfg) +} + +// Create creates all the necessary files and directories for a new local +// backend at dir. Afterwards a new config blob should be created. +func Create(ctx context.Context, cfg Config) (*Backend, error) { + debug.Log("create local backend at %v (layout %q)", cfg.Path, cfg.Layout) + + b, err := open(ctx, cfg) + if err != nil { + return nil, err + } + + cn, err := b.getConnection(ctx, cfg.ShareName) + if err != nil { + return b, err + } + defer b.putConnection(&cn) + + // test if config file already exists + _, err = cn.smbShare.Lstat(b.Filename(restic.Handle{Type: restic.ConfigFile})) + if err == nil { + return nil, errors.New("config file already exists") + } + + // create paths for data and refs + for _, d := range b.Paths() { + err := cn.smbShare.MkdirAll(d, b.Modes.Dir) + if err != nil { + return nil, errors.WithStack(err) + } + } + + return b, nil +} + +func (b *Backend) Connections() uint { + return b.Config.Connections +} + +// Location returns this backend's location (the directory name). +func (b *Backend) Location() string { + return b.Join(b.ShareName, b.Path) +} + +// Hasher may return a hash function for calculating a content hash for the backend +func (b *Backend) Hasher() hash.Hash { + return nil +} + +// HasAtomicReplace returns whether Save() can atomically replace files +func (b *Backend) HasAtomicReplace() bool { + return true +} + +// IsNotExist returns true if the error is caused by a non existing file. +func (b *Backend) IsNotExist(err error) bool { + return errors.Is(err, os.ErrNotExist) +} + +// Join combines path components with slashes. +func (be *Backend) Join(p ...string) string { + return path.Join(p...) +} + +// Save stores data in the backend at the handle. +func (b *Backend) Save(ctx context.Context, h restic.Handle, rd restic.RewindReader) (err error) { + debug.Log("Save %v", h) + if err := h.Valid(); err != nil { + return backoff.Permanent(err) + } + + finalname := b.Filename(h) + dir := filepath.Dir(finalname) + + defer func() { + // Mark non-retriable errors as such + if errors.Is(err, syscall.ENOSPC) || os.IsPermission(err) { + err = backoff.Permanent(err) + } + }() + + b.sem.GetToken() + defer b.sem.ReleaseToken() + + // Create new file with a temporary name. + tmpname := filepath.Base(finalname) + "-tmp-" + + b.addSession() // Show session in use + defer b.removeSession() + + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return err + } + defer b.putConnection(&cn) + + f, err := b.CreateTemp(cn, dir, tmpname) + + if b.IsNotExist(err) { + debug.Log("error %v: creating dir", err) + + // error is caused by a missing directory, try to create it + mkdirErr := cn.smbShare.MkdirAll(dir, b.Modes.Dir) + if mkdirErr != nil { + debug.Log("error creating dir %v: %v", dir, mkdirErr) + } else { + // try again + f, err = b.CreateTemp(cn, dir, tmpname) + } + } + + if err != nil { + return errors.WithStack(err) + } + + defer func(f *smb2.File) { + if err != nil { + _ = f.Close() // Double Close is harmless. + // Remove after Rename is harmless: we embed the final name in the + // temporary's name and no other goroutine will get the same data to + // Save, so the temporary name should never be reused by another + // goroutine. + _ = cn.smbShare.Remove(f.Name()) + } + }(f) + + // save data, then sync + wbytes, err := io.Copy(f, rd) + if err != nil { + return errors.WithStack(err) + } + // sanity check + if wbytes != rd.Length() { + return errors.Errorf("wrote %d bytes instead of the expected %d bytes", wbytes, rd.Length()) + } + + // Ignore error if filesystem does not support fsync. + // In this case the sync call is on the smb client's file. + err = f.Sync() + syncNotSup := err != nil && (errors.Is(err, syscall.ENOTSUP)) + if err != nil && !syncNotSup { + return errors.WithStack(err) + } + + // Close, then rename. Windows doesn't like the reverse order. + if err = f.Close(); err != nil { + return errors.WithStack(err) + } + if err = cn.smbShare.Rename(f.Name(), finalname); err != nil { + return errors.WithStack(err) + } + + // try to mark file as read-only to avoid accidential modifications + // ignore if the operation fails as some filesystems don't allow the chmod call + // e.g. exfat and network file systems with certain mount options + err = cn.setFileReadonly(finalname, b.Modes.File) + if err != nil && !os.IsPermission(err) { + return errors.WithStack(err) + } + + return nil +} + +// set file to readonly +func (cn *conn) setFileReadonly(f string, mode os.FileMode) error { + return cn.smbShare.Chmod(f, mode&^0222) +} + +// Load runs fn with a reader that yields the contents of the file at h at the +// given offset. +func (b *Backend) Load(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error { + return backend.DefaultLoad(ctx, h, length, offset, b.openReader, fn) +} + +func (b *Backend) openReader(ctx context.Context, h restic.Handle, length int, offset int64) (io.ReadCloser, error) { + debug.Log("Load %v, length %v, offset %v", h, length, offset) + if err := h.Valid(); err != nil { + return nil, backoff.Permanent(err) + } + + if offset < 0 { + return nil, errors.New("offset is negative") + } + + b.addSession() // Show session in use + defer b.removeSession() + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return nil, err + } + defer b.putConnection(&cn) + + b.sem.GetToken() + f, err := cn.smbShare.Open(b.Filename(h)) + if err != nil { + b.sem.ReleaseToken() + return nil, err + } + + if offset > 0 { + _, err = f.Seek(offset, 0) + if err != nil { + b.sem.ReleaseToken() + _ = f.Close() + return nil, err + } + } + + r := b.sem.ReleaseTokenOnClose(f, nil) + + if length > 0 { + return backend.LimitReadCloser(r, int64(length)), nil + } + + return r, nil +} + +// Stat returns information about a blob. +func (b *Backend) Stat(ctx context.Context, h restic.Handle) (restic.FileInfo, error) { + debug.Log("Stat %v", h) + if err := h.Valid(); err != nil { + return restic.FileInfo{}, backoff.Permanent(err) + } + + b.sem.GetToken() + defer b.sem.ReleaseToken() + + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return restic.FileInfo{}, err + } + defer b.putConnection(&cn) + + fi, err := cn.smbShare.Stat(b.Filename(h)) + if err != nil { + return restic.FileInfo{}, errors.WithStack(err) + } + + return restic.FileInfo{Size: fi.Size(), Name: h.Name}, nil +} + +// Remove removes the blob with the given name and type. +func (b *Backend) Remove(ctx context.Context, h restic.Handle) error { + debug.Log("Remove %v", h) + fn := b.Filename(h) + + b.sem.GetToken() + defer b.sem.ReleaseToken() + + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return err + } + defer b.putConnection(&cn) + + // reset read-only flag + err = cn.smbShare.Chmod(fn, 0666) + if err != nil && !os.IsPermission(err) { + return errors.WithStack(err) + } + + return cn.smbShare.Remove(fn) +} + +// List runs fn for each file in the backend which has the type t. When an +// error occurs (or fn returns an error), List stops and returns it. +func (b *Backend) List(ctx context.Context, t restic.FileType, fn func(restic.FileInfo) error) (err error) { + debug.Log("List %v", t) + + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return err + } + defer b.putConnection(&cn) + + basedir, subdirs := b.Basedir(t) + if subdirs { + err = b.visitDirs(cn, ctx, basedir, fn) + } else { + err = b.visitFiles(cn, ctx, basedir, fn, false) + } + + if b.IsNotExist(err) { + debug.Log("ignoring non-existing directory") + return nil + } + + return err +} + +// The following two functions are like filepath.Walk, but visit only one or +// two levels of directory structure (including dir itself as the first level). +// Also, visitDirs assumes it sees a directory full of directories, while +// visitFiles wants a directory full or regular files. +func (b *Backend) visitDirs(cn *conn, ctx context.Context, dir string, fn func(restic.FileInfo) error) error { + d, err := cn.smbShare.Open(dir) + if err != nil { + return err + } + + sub, err := d.Readdirnames(-1) + if err != nil { + // ignore subsequent errors + _ = d.Close() + return err + } + + err = d.Close() + if err != nil { + return err + } + + for _, f := range sub { + err = b.visitFiles(cn, ctx, filepath.Join(dir, f), fn, true) + if err != nil { + return err + } + } + return ctx.Err() +} + +func (b *Backend) visitFiles(cn *conn, ctx context.Context, dir string, fn func(restic.FileInfo) error, ignoreNotADirectory bool) error { + d, err := cn.smbShare.Open(dir) + if err != nil { + return err + } + + if ignoreNotADirectory { + fi, err := d.Stat() + if err != nil || !fi.IsDir() { + // ignore subsequent errors + _ = d.Close() + return err + } + } + + sub, err := d.Readdir(-1) + if err != nil { + // ignore subsequent errors + _ = d.Close() + return err + } + + err = d.Close() + if err != nil { + return err + } + + for _, fi := range sub { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + err := fn(restic.FileInfo{ + Name: fi.Name(), + Size: fi.Size(), + }) + if err != nil { + return err + } + } + return nil +} + +// Delete removes the repository and all files. +func (b *Backend) Delete(ctx context.Context) error { + debug.Log("Delete()") + cn, err := b.getConnection(ctx, b.ShareName) + if err != nil { + return err + } + defer b.putConnection(&cn) + return cn.smbShare.RemoveAll(b.Location()) +} + +// Close closes all open files. +func (b *Backend) Close() error { + debug.Log("Close()") + err := b.drainPool() + return err +} + +var ( + ErrExist = fs.ErrExist // "file already exists" +) + +// PathError records an error and the operation and file path that caused it. +type PathError = fs.PathError + +const ( + PathSeparator = '/' // OS-specific path separator + PathListSeparator = ';' // OS-specific path list separator +) + +// CreateTemp creates a new temporary file in the directory dir, +// opens the file for reading and writing, and returns the resulting file. +// The filename is generated by taking pattern and adding a random string to the end. +// If pattern includes a "*", the random string replaces the last "*". +// If dir is the empty string, CreateTemp uses the default directory for temporary files, as returned by TempDir. +// Multiple programs or goroutines calling CreateTemp simultaneously will not choose the same file. +// The caller can use the file's Name method to find the pathname of the file. +// It is the caller's responsibility to remove the file when it is no longer needed. +func (b *Backend) CreateTemp(cn *conn, dir, pattern string) (*smb2.File, error) { + if dir == "" { + dir = os.TempDir() + } + + prefix, suffix, err := prefixAndSuffix(pattern) + if err != nil { + return nil, &PathError{Op: "createtemp", Path: pattern, Err: err} + } + prefix = joinPath(dir, prefix) + + try := 0 + for { + name := prefix + nextRandom() + suffix + f, err := cn.smbShare.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600) + + if os.IsExist(err) { + if try++; try < 10000 { + continue + } + return nil, &PathError{Op: "createtemp", Path: prefix + "*" + suffix, Err: ErrExist} + } + return f, err + } +} + +var errPatternHasSeparator = errors.New("pattern contains path separator") + +// prefixAndSuffix splits pattern by the last wildcard "*", if applicable, +// returning prefix as the part before "*" and suffix as the part after "*". +func prefixAndSuffix(pattern string) (prefix, suffix string, err error) { + for i := 0; i < len(pattern); i++ { + if IsPathSeparator(pattern[i]) { + return "", "", errPatternHasSeparator + } + } + if pos := lastIndex(pattern, '*'); pos != -1 { + prefix, suffix = pattern[:pos], pattern[pos+1:] + } else { + prefix = pattern + } + return prefix, suffix, nil +} + +// LastIndexByte from the strings package. +func lastIndex(s string, sep byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == sep { + return i + } + } + return -1 +} + +func nextRandom() string { + return strconv.FormatUint(uint64(fastrand.Uint32()), 10) +} + +func joinPath(dir, name string) string { + if len(dir) > 0 && IsPathSeparator(dir[len(dir)-1]) { + return dir + name + } + return dir + string(PathSeparator) + name +} + +// IsPathSeparator reports whether c is a directory separator character. +func IsPathSeparator(c uint8) bool { + // NOTE: Windows accepts / as path separator. + return c == '\\' || c == '/' +} diff --git a/internal/backend/smb/smb_test.go b/internal/backend/smb/smb_test.go new file mode 100644 index 000000000..86a32b785 --- /dev/null +++ b/internal/backend/smb/smb_test.go @@ -0,0 +1,77 @@ +package smb_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + "github.com/restic/restic/internal/backend/smb" + "github.com/restic/restic/internal/backend/test" + "github.com/restic/restic/internal/options" + "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" +) + +func newTestSuite(t testing.TB) *test.Suite { + return &test.Suite{ + // NewConfig returns a config for a new temporary backend that will be used in tests. + NewConfig: func() (interface{}, error) { + + cfg := smb.NewConfig() + cfg.Address = "127.0.0.1" + cfg.User = "smbuser" + cfg.ShareName = cfg.User + cfg.Path = "Repo-" + uuid.New().String() + cfg.Password = options.NewSecretString("mGoWwqvgdnwtmh07") + cfg.Connections = smb.DefaultConnections + timeout := smb.DefaultIdleTimeout + cfg.IdleTimeout = &timeout + cfg.Domain = smb.DefaultDomain + + t.Logf("create new backend at %v", cfg.Address+"/"+cfg.ShareName) + + return cfg, nil + }, + + // CreateFn is a function that creates a temporary repository for the tests. + Create: func(config interface{}) (restic.Backend, error) { + cfg := config.(smb.Config) + return smb.Create(context.TODO(), cfg) + }, + + // OpenFn is a function that opens a previously created temporary repository. + Open: func(config interface{}) (restic.Backend, error) { + cfg := config.(smb.Config) + return smb.Open(context.TODO(), cfg) + }, + + // CleanupFn removes data created during the tests. + Cleanup: func(config interface{}) error { + cfg := config.(smb.Config) + if !rtest.TestCleanupTempDirs { + t.Logf("leaving test backend dir at %v", cfg.Path) + } + + rtest.RemoveAll(t, cfg.Path) + return nil + }, + } +} + +func TestBackendSMB(t *testing.T) { + if !rtest.RunSMBTest { + t.Skip("Skipping smb tests") + } + t.Logf("run tests") + + newTestSuite(t).RunTests(t) +} + +func BenchmarkBackendSMB(t *testing.B) { + if !rtest.RunSMBTest { + t.Skip("Skipping smb tests") + } + t.Logf("run benchmarks") + + newTestSuite(t).RunBenchmarks(t) +} diff --git a/internal/test/vars.go b/internal/test/vars.go index b6b76541e..e2b2497ef 100644 --- a/internal/test/vars.go +++ b/internal/test/vars.go @@ -13,6 +13,7 @@ var ( TestTempDir = getStringVar("RESTIC_TEST_TMPDIR", "") RunIntegrationTest = getBoolVar("RESTIC_TEST_INTEGRATION", true) RunFuseTest = getBoolVar("RESTIC_TEST_FUSE", true) + RunSMBTest = getBoolVar("RESTIC_TEST_SMB", true) TestSFTPPath = getStringVar("RESTIC_TEST_SFTPPATH", "/usr/lib/ssh:/usr/lib/openssh:/usr/libexec") TestWalkerPath = getStringVar("RESTIC_TEST_PATH", ".") BenchArchiveDirectory = getStringVar("RESTIC_BENCH_DIR", ".")