From a9b08cd5d015ba26c6d79a9a6384613021b203f2 Mon Sep 17 00:00:00 2001 From: aneesh-n <99904+aneesh-n@users.noreply.github.com> Date: Sun, 1 Sep 2024 22:52:30 -0600 Subject: [PATCH] Fix merge issues and update smb code similar to local --- internal/backend/smb/conpool.go | 40 +++++----- internal/backend/smb/smb.go | 131 +++++++++++++++++--------------- 2 files changed, 90 insertions(+), 81 deletions(-) diff --git a/internal/backend/smb/conpool.go b/internal/backend/smb/conpool.go index 14379cc0c..a38923cf2 100644 --- a/internal/backend/smb/conpool.go +++ b/internal/backend/smb/conpool.go @@ -2,6 +2,7 @@ package smb import ( "context" + "errors" "fmt" "net" "strconv" @@ -40,16 +41,15 @@ type conn struct { shareName string } -// Closes the connection -func (c *conn) close() (err error) { +func (c *conn) close() error { + var err, errLogoff error if c.smbShare != nil { err = c.smbShare.Umount() } - sessionLogoffErr := c.smbSession.Logoff() - if err != nil { - return err + if c.smbSession != nil { + errLogoff = c.smbSession.Logoff() } - return sessionLogoffErr + return errors.Join(err, errLogoff) } // True if it's closed @@ -68,24 +68,24 @@ func (c *conn) closed() bool { // Show that we are using a SMB session // // Call removeSession() when done -func (b *Backend) addSession() { +func (b *SMB) addSession() { atomic.AddInt32(&b.sessions, 1) } // Show the SMB session is no longer in use -func (b *Backend) removeSession() { +func (b *SMB) removeSession() { atomic.AddInt32(&b.sessions, -1) } // getSessions shows whether there are any sessions in use -func (b *Backend) getSessions() int32 { +func (b *SMB) getSessions() int32 { return atomic.LoadInt32(&b.sessions) } // dial starts a client connection to the given SMB server. It is a // convenience function that connects to the given network address, // initiates the SMB handshake, and then returns a session for SMB communication. -func (b *Backend) dial(ctx context.Context, network, addr string) (*conn, error) { +func (b *SMB) dial(ctx context.Context, network, addr string) (*conn, error) { dialer := net.Dialer{} tconn, err := dialer.Dial(network, addr) if err != nil { @@ -121,7 +121,7 @@ func (b *Backend) dial(ctx context.Context, network, addr string) (*conn, error) } // Open a new connection to the SMB server. -func (b *Backend) newConnection(share string) (c *conn, err error) { +func (b *SMB) newConnection(share string) (c *conn, err error) { // As we are pooling these connections we need to decouple // them from the current context ctx := context.Background() @@ -167,7 +167,7 @@ func (c *conn) mountShare(share string) (err error) { } // Get a SMB connection from the pool, or open a new one -func (b *Backend) getConnection(_ context.Context, share string) (c *conn, err error) { +func (b *SMB) getConnection(share string) (c *conn, err error) { b.poolMu.Lock() for len(b.pool) > 0 { c = b.pool[0] @@ -187,8 +187,12 @@ func (b *Backend) getConnection(_ context.Context, share string) (c *conn, err e return c, err } -// Return a SMB connection to the pool -func (b *Backend) putConnection(c *conn) { +// put the connection back into the connection pool for reuse +func (b *SMB) putConnection(c *conn) { + if c == nil { + return + } + var nopErr error if c.smbShare != nil { // stat the current directory @@ -205,20 +209,20 @@ func (b *Backend) putConnection(c *conn) { b.poolMu.Lock() b.pool = append(b.pool, c) - b.drain.Reset(b.Config.IdleTimeout) // nudge on the pool emptying timer + b.drain.Reset(b.IdleTimeout) b.poolMu.Unlock() } // Drain the pool of any connections -func (b *Backend) drainPool() (err error) { +func (b *SMB) drainPool() (err error) { b.poolMu.Lock() defer b.poolMu.Unlock() if sessions := b.getSessions(); sessions != 0 { debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions) - b.drain.Reset(b.Config.IdleTimeout) // nudge on the pool emptying timer + b.drain.Reset(b.IdleTimeout) // nudge on the pool emptying timer return nil } - if b.Config.IdleTimeout > 0 { + if b.IdleTimeout > 0 { b.drain.Stop() } if len(b.pool) != 0 { diff --git a/internal/backend/smb/smb.go b/internal/backend/smb/smb.go index 367069c59..b5f00751f 100644 --- a/internal/backend/smb/smb.go +++ b/internal/backend/smb/smb.go @@ -22,7 +22,6 @@ import ( "github.com/restic/restic/internal/backend/util" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/errors" - "github.com/restic/restic/internal/restic" ) // Parts of this code have been adapted from Rclone (https://github.com/rclone) @@ -46,37 +45,31 @@ import ( // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -// Backend stores data on an SMB endpoint. -type Backend struct { - Config - layout.Layout - util.Modes - +// SMB is a backend which stores the data on an SMB share. +type SMB struct { sessions int32 poolMu sync.Mutex pool []*conn drain *time.Timer // used to drain the pool when we stop using the connections + + layout.Layout + Config + util.Modes } -// make sure that *Backend implements backend.Backend -var _ backend.Backend = &Backend{} +// ensure statically that *SMB implements backend.Backend. +var _ backend.Backend = &SMB{} + +var errTooShort = errors.New("file is too short") func NewFactory() location.Factory { return location.NewLimitedBackendFactory("smb", ParseConfig, location.NoPassword, limiter.WrapBackendConstructor(Create), limiter.WrapBackendConstructor(Open)) } -const ( - defaultLayout = "default" -) +func open(cfg Config) (*SMB, error) { + l := layout.NewDefaultLayout(cfg.Path, filepath.Join) -func open(ctx context.Context, cfg Config) (*Backend, error) { - - l, err := layout.ParseLayout(ctx, &layout.LocalFilesystem{}, cfg.Layout, defaultLayout, cfg.Path) - if err != nil { - return nil, err - } - - b := &Backend{ + b := &SMB{ Config: cfg, Layout: l, } @@ -88,13 +81,13 @@ func open(ctx context.Context, cfg Config) (*Backend, error) { b.drain = time.AfterFunc(b.Config.IdleTimeout, func() { _ = b.drainPool() }) } - cn, err := b.getConnection(ctx, b.ShareName) + cn, err := b.getConnection(b.ShareName) if err != nil { return nil, err } defer b.putConnection(cn) - stat, err := cn.smbShare.Stat(l.Filename(backend.Handle{Type: restic.ConfigFile})) + stat, err := cn.smbShare.Stat(l.Filename(backend.Handle{Type: backend.ConfigFile})) m := util.DeriveModesFromFileInfo(stat, err) debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir) @@ -103,30 +96,29 @@ func open(ctx context.Context, cfg Config) (*Backend, error) { return b, nil } -// Open opens the local backend as specified by config. -func Open(ctx context.Context, cfg Config) (*Backend, error) { - debug.Log("open local backend at %v (layout %q)", cfg.Path, cfg.Layout) - return open(ctx, cfg) +// Open opens the SMB backend as specified by the config. +func Open(_ context.Context, cfg Config) (*SMB, error) { + debug.Log("open smb backend at %v (share %q)", cfg.Path, cfg.ShareName) + return open(cfg) } -// Create creates all the necessary files and directories for a new local -// backend at dir. Afterwards a new config blob should be created. -func Create(ctx context.Context, cfg Config) (*Backend, error) { - debug.Log("create local backend at %v (layout %q)", cfg.Path, cfg.Layout) +// Create creates all the necessary files and directories for a new SMB backend. +func Create(_ context.Context, cfg Config) (*SMB, error) { + debug.Log("create smb backend at %v (share %q)", cfg.Path, cfg.ShareName) - b, err := open(ctx, cfg) + b, err := open(cfg) if err != nil { return nil, err } - cn, err := b.getConnection(ctx, cfg.ShareName) + cn, err := b.getConnection(cfg.ShareName) if err != nil { return b, err } defer b.putConnection(cn) // test if config file already exists - _, err = cn.smbShare.Lstat(b.Filename(backend.Handle{Type: restic.ConfigFile})) + _, err = cn.smbShare.Lstat(b.Filename(backend.Handle{Type: backend.ConfigFile})) if err == nil { return nil, errors.New("config file already exists") } @@ -142,37 +134,36 @@ func Create(ctx context.Context, cfg Config) (*Backend, error) { return b, nil } -func (b *Backend) Connections() uint { +func (b *SMB) Connections() uint { return b.Config.Connections } -// Location returns this backend's location (the directory name). -func (b *Backend) Location() string { - return b.Join(b.ShareName, b.Path) -} - // Hasher may return a hash function for calculating a content hash for the backend -func (b *Backend) Hasher() hash.Hash { +func (b *SMB) Hasher() hash.Hash { return nil } // HasAtomicReplace returns whether Save() can atomically replace files -func (b *Backend) HasAtomicReplace() bool { +func (b *SMB) HasAtomicReplace() bool { return true } // IsNotExist returns true if the error is caused by a non existing file. -func (b *Backend) IsNotExist(err error) bool { +func (b *SMB) IsNotExist(err error) bool { return errors.Is(err, os.ErrNotExist) } // Join combines path components with slashes. -func (b *Backend) Join(p ...string) string { +func (b *SMB) Join(p ...string) string { return path.Join(p...) } +func (b *SMB) IsPermanentError(err error) bool { + return b.IsNotExist(err) || errors.Is(err, errTooShort) || errors.Is(err, os.ErrPermission) +} + // Save stores data in the backend at the handle. -func (b *Backend) Save(ctx context.Context, h backend.Handle, rd backend.RewindReader) (err error) { +func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) (err error) { filename := b.Filename(h) tmpFilename := filename + "-restic-temp-" + tempSuffix() dir := filepath.Dir(tmpFilename) @@ -187,14 +178,16 @@ func (b *Backend) Save(ctx context.Context, h backend.Handle, rd backend.RewindR b.addSession() // Show session in use defer b.removeSession() - cn, err := b.getConnection(ctx, b.ShareName) + var cn *conn + cn, err = b.getConnection(b.ShareName) if err != nil { return err } defer b.putConnection(cn) + var f *smb2.File // create new file - f, err := cn.smbShare.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) + f, err = cn.smbShare.OpenFile(tmpFilename, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if b.IsNotExist(err) { debug.Log("error %v: creating dir", err) @@ -250,7 +243,7 @@ func (b *Backend) Save(ctx context.Context, h backend.Handle, rd backend.RewindR return errors.WithStack(err) } - // try to mark file as read-only to avoid accidential modifications + // try to mark file as read-only to avoid accidental modifications // ignore if the operation fails as some filesystems don't allow the chmod call // e.g. exfat and network file systems with certain mount options err = cn.setFileReadonly(filename, b.Modes.File) @@ -268,14 +261,14 @@ func (cn *conn) setFileReadonly(f string, mode os.FileMode) error { // Load runs fn with a reader that yields the contents of the file at h at the // given offset. -func (b *Backend) Load(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { +func (b *SMB) Load(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { return util.DefaultLoad(ctx, h, length, offset, b.openReader, fn) } -func (b *Backend) openReader(ctx context.Context, h backend.Handle, length int, offset int64) (io.ReadCloser, error) { +func (b *SMB) openReader(_ context.Context, h backend.Handle, length int, offset int64) (io.ReadCloser, error) { b.addSession() // Show session in use defer b.removeSession() - cn, err := b.getConnection(ctx, b.ShareName) + cn, err := b.getConnection(b.ShareName) if err != nil { return nil, err } @@ -286,6 +279,18 @@ func (b *Backend) openReader(ctx context.Context, h backend.Handle, length int, return nil, err } + fi, err := f.Stat() + if err != nil { + _ = f.Close() + return nil, err + } + + size := fi.Size() + if size < offset+int64(length) { + _ = f.Close() + return nil, errTooShort + } + if offset > 0 { _, err = f.Seek(offset, 0) if err != nil { @@ -295,15 +300,15 @@ func (b *Backend) openReader(ctx context.Context, h backend.Handle, length int, } if length > 0 { - return backend.LimitReadCloser(f, int64(length)), nil + return util.LimitReadCloser(f, int64(length)), nil } return f, nil } // Stat returns information about a blob. -func (b *Backend) Stat(ctx context.Context, h backend.Handle) (backend.FileInfo, error) { - cn, err := b.getConnection(ctx, b.ShareName) +func (b *SMB) Stat(_ context.Context, h backend.Handle) (backend.FileInfo, error) { + cn, err := b.getConnection(b.ShareName) if err != nil { return backend.FileInfo{}, err } @@ -318,10 +323,10 @@ func (b *Backend) Stat(ctx context.Context, h backend.Handle) (backend.FileInfo, } // Remove removes the blob with the given name and type. -func (b *Backend) Remove(ctx context.Context, h backend.Handle) error { +func (b *SMB) Remove(_ context.Context, h backend.Handle) error { fn := b.Filename(h) - cn, err := b.getConnection(ctx, b.ShareName) + cn, err := b.getConnection(b.ShareName) if err != nil { return err } @@ -338,8 +343,8 @@ func (b *Backend) Remove(ctx context.Context, h backend.Handle) error { // List runs fn for each file in the backend which has the type t. When an // error occurs (or fn returns an error), List stops and returns it. -func (b *Backend) List(ctx context.Context, t restic.FileType, fn func(backend.FileInfo) error) (err error) { - cn, err := b.getConnection(ctx, b.ShareName) +func (b *SMB) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) error { + cn, err := b.getConnection(b.ShareName) if err != nil { return err } @@ -364,7 +369,7 @@ func (b *Backend) List(ctx context.Context, t restic.FileType, fn func(backend.F // two levels of directory structure (including dir itself as the first level). // Also, visitDirs assumes it sees a directory full of directories, while // visitFiles wants a directory full or regular files. -func (b *Backend) visitDirs(ctx context.Context, cn *conn, dir string, fn func(backend.FileInfo) error) error { +func (b *SMB) visitDirs(ctx context.Context, cn *conn, dir string, fn func(backend.FileInfo) error) error { d, err := cn.smbShare.Open(dir) if err != nil { return err @@ -391,7 +396,7 @@ func (b *Backend) visitDirs(ctx context.Context, cn *conn, dir string, fn func(b return ctx.Err() } -func (b *Backend) visitFiles(ctx context.Context, cn *conn, dir string, fn func(backend.FileInfo) error, ignoreNotADirectory bool) error { +func (b *SMB) visitFiles(ctx context.Context, cn *conn, dir string, fn func(backend.FileInfo) error, ignoreNotADirectory bool) error { d, err := cn.smbShare.Open(dir) if err != nil { return err @@ -437,17 +442,17 @@ func (b *Backend) visitFiles(ctx context.Context, cn *conn, dir string, fn func( } // Delete removes the repository and all files. -func (b *Backend) Delete(ctx context.Context) error { - cn, err := b.getConnection(ctx, b.ShareName) +func (b *SMB) Delete(_ context.Context) error { + cn, err := b.getConnection(b.ShareName) if err != nil { return err } defer b.putConnection(cn) - return cn.smbShare.RemoveAll(b.Location()) + return cn.smbShare.RemoveAll(b.Join(b.ShareName, b.Path)) } // Close closes all open files. -func (b *Backend) Close() error { +func (b *SMB) Close() error { err := b.drainPool() return err }