diff --git a/doc/030_preparing_a_new_repo.rst b/doc/030_preparing_a_new_repo.rst index 69a8f9739..8152b8b31 100644 --- a/doc/030_preparing_a_new_repo.rst +++ b/doc/030_preparing_a_new_repo.rst @@ -639,7 +639,7 @@ You must first setup the following environment variables with the SMB credential $ export RESTIC_SMB_DOMAIN= -Once the server is configured, the setup of the SFTP repository can +Once the server is configured, the setup of the SMB repository can simply be achieved by changing the URL scheme in the ``init`` command: .. code-block:: console diff --git a/internal/backend/local/local.go b/internal/backend/local/local.go index a953dcae0..fbd7e6db8 100644 --- a/internal/backend/local/local.go +++ b/internal/backend/local/local.go @@ -15,16 +15,17 @@ import ( "github.com/restic/restic/internal/debug" ) -// Local is a backend in a local directory. +// Local represents a backend in a local directory. type Local struct { Config layout.Layout util.Modes } -// ensure statically that *Local implements backend.Backend. +// Ensure statically that *Local implements backend.Backend. var _ backend.Backend = &Local{} +// NewFactory returns a new factory for local backends. func NewFactory() location.Factory { return location.NewLimitedBackendFactory("local", ParseConfig, location.NoPassword, limiter.WrapBackendConstructor(Create), limiter.WrapBackendConstructor(Open)) } @@ -60,31 +61,33 @@ func Create(_ context.Context, cfg Config) (*Local, error) { return be, nil } +// Connections returns the number of configured connections. func (b *Local) Connections() uint { return b.Config.Connections } -// Hasher may return a hash function for calculating a content hash for the backend +// Hasher returns a hash function for calculating a content hash for the backend. func (b *Local) Hasher() hash.Hash { return nil } -// HasAtomicReplace returns whether Save() can atomically replace files +// HasAtomicReplace returns whether Save() can atomically replace files. func (b *Local) HasAtomicReplace() bool { return true } -// IsNotExist returns true if the error is caused by a non existing file. +// IsNotExist returns true if the error is caused by a non-existing file. func (b *Local) IsNotExist(err error) bool { return util.IsNotExist(err) } +// IsPermanentError returns true if the error is permanent. func (b *Local) IsPermanentError(err error) bool { return util.IsPermanentError(err) } // Save stores data in the backend at the handle. -func (b *Local) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) (err error) { +func (b *Local) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) error { fileName := b.Filename(h) // Create new file with a temporary name. tmpFilename := filepath.Base(fileName) + "-tmp-" @@ -137,7 +140,7 @@ func (b *Local) Remove(_ context.Context, h backend.Handle) error { // List runs fn for each file in the backend which has the type t. When an // error occurs (or fn returns an error), List stops and returns it. -func (b *Local) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) (err error) { +func (b *Local) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) error { openFunc := func(name string) (util.File, error) { return os.Open(name) } @@ -152,7 +155,7 @@ func (b *Local) Delete(_ context.Context) error { // Close closes all open files. func (b *Local) Close() error { - // this does not need to do anything, all open files are closed within the + // This does not need to do anything, all open files are closed within the // same function. return nil } diff --git a/internal/backend/smb/config.go b/internal/backend/smb/config.go index 0d133c2ba..55f513a81 100644 --- a/internal/backend/smb/config.go +++ b/internal/backend/smb/config.go @@ -12,6 +12,13 @@ import ( "github.com/restic/restic/internal/options" ) +const ( + DefaultSMBPort int = 445 // DefaultSMBt returns the default port for SMB + DefaultDomain string = "WORKGROUP" // DefaultDomain returns the default domain for SMB + DefaultConnections uint = 5 // DefaultConnections returns the number of concurrent connections for SMB. + DefaultIdleTimeout time.Duration = 60 * time.Second // DefaultIdleTimeout returns the default max time before closing idle connections for SMB. +) + // Config contains all configuration necessary to connect to an SMB server type Config struct { Host string @@ -31,17 +38,10 @@ type Config struct { ClientGUID string `option:"client-guid" help:"A 16-byte GUID to uniquely identify a client. If not specific a random GUID is used. (default: \"\")"` } -const ( - DefaultSmbPort int = 445 // DefaultSmbPort returns the default port for SMB - DefaultDomain string = "WORKGROUP" // DefaultDomain returns the default domain for SMB - DefaultConnections uint = 5 // DefaultConnections returns the number of concurrent connections for SMB. - DefaultIdleTimeout time.Duration = 60 * time.Second // DefaultIdleTimeout returns the default max time before closing idle connections for SMB. -) - // NewConfig returns a new Config with the default values filled in. func NewConfig() Config { return Config{ - Port: DefaultSmbPort, + Port: DefaultSMBPort, Domain: DefaultDomain, IdleTimeout: DefaultIdleTimeout, Connections: DefaultConnections, @@ -52,76 +52,52 @@ func init() { options.Register("smb", Config{}) } -// ParseConfig parses the string s and extracts the s3 config. The +// ParseConfig parses the string s and extracts the SMB config. The // supported configuration format is smb://[user@]host[:port]/sharename/directory. // User and port are optional. Default port is 445. func ParseConfig(s string) (*Config, error) { - var repo string - switch { - case strings.HasPrefix(s, "smb://"): - repo = s - case strings.HasPrefix(s, "smb:"): - repo = "smb://" + s[4:] - default: + hasSmbColonPrefix := strings.HasPrefix(s, "smb:") + if !hasSmbColonPrefix && !strings.HasPrefix(s, "smb://") { return nil, errors.New("smb: invalid format") } - var user, host, port, dir string + if hasSmbColonPrefix { + s = "smb://" + s[4:] + } // parse the "smb://user@host/sharename/directory." url format - url, err := url.Parse(repo) + u, err := url.Parse(s) if err != nil { return nil, errors.WithStack(err) } - if url.User != nil { - user = url.User.Username() + + cfg := NewConfig() + + if u.User != nil { + cfg.User = u.User.Username() //Intentionally not allowing passwords to be set in url as //it can cause issues when passwords have special characters //like '@' and it is not recommended to pass passwords in the url. } - host = url.Hostname() - if host == "" { + cfg.Host = u.Hostname() + if cfg.Host == "" { return nil, errors.New("smb: invalid format, host name not found") } - port = url.Port() - dir = url.Path - if dir == "" { - return nil, errors.Errorf("smb: invalid format, sharename/directory not found") - } - dir = dir[1:] - - var portNum int - if port == "" { - portNum = DefaultSmbPort - } else { - var err error - portNum, err = strconv.Atoi(port) + if u.Port() != "" { + cfg.Port, err = strconv.Atoi(u.Port()) if err != nil { - return nil, err + return nil, errors.Wrap(err, "invalid port number") } } - sharename, directory, _ := strings.Cut(dir, "/") - - return createConfig(user, host, portNum, sharename, directory) -} - -func createConfig(user string, host string, port int, sharename, directory string) (*Config, error) { - if host == "" { - return nil, errors.New("smb: invalid format, Host not found") + if u.Path == "" { + return nil, errors.New("smb: invalid format, sharename/directory not found") } - if directory != "" { - directory = path.Clean(directory) - } + cfg.ShareName, cfg.Path, _ = strings.Cut(u.Path[1:], "/") + cfg.Path = path.Clean(cfg.Path) - cfg := NewConfig() - cfg.User = user - cfg.Host = host - cfg.Port = port - cfg.ShareName = sharename - cfg.Path = directory return &cfg, nil } @@ -135,9 +111,9 @@ func (cfg *Config) ApplyEnvironment(prefix string) error { } if cfg.Domain == "" { cfg.Domain = os.Getenv(prefix + "RESTIC_SMB_DOMAIN") - } - if cfg.Domain == "" { - cfg.Domain = DefaultDomain + if cfg.Domain == "" { + cfg.Domain = DefaultDomain + } } return nil } diff --git a/internal/backend/smb/config_test.go b/internal/backend/smb/config_test.go index 9b3e1dee2..85700ff80 100644 --- a/internal/backend/smb/config_test.go +++ b/internal/backend/smb/config_test.go @@ -11,7 +11,7 @@ var configTests = []test.ConfigTestData[Config]{ {S: "smb://user@host/sharename/directory", Cfg: Config{ Host: "host", - Port: DefaultSmbPort, + Port: DefaultSMBPort, User: "user", Domain: DefaultDomain, ShareName: "sharename", @@ -33,7 +33,7 @@ var configTests = []test.ConfigTestData[Config]{ {S: "smb://host/sharename/directory", Cfg: Config{ Host: "host", - Port: DefaultSmbPort, + Port: DefaultSMBPort, Domain: DefaultDomain, ShareName: "sharename", Path: "directory", diff --git a/internal/backend/smb/conpool.go b/internal/backend/smb/conpool.go index a38923cf2..3720933bf 100644 --- a/internal/backend/smb/conpool.go +++ b/internal/backend/smb/conpool.go @@ -33,52 +33,50 @@ import ( // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -// conn encapsulates a SMB client and corresponding SMB client +// conn encapsulates a SMB client and corresponding SMB session and share. type conn struct { - conn *net.Conn + netConn net.Conn smbSession *smb2.Session smbShare *smb2.Share shareName string } func (c *conn) close() error { - var err, errLogoff error + var errs []error if c.smbShare != nil { - err = c.smbShare.Umount() + errs = append(errs, c.smbShare.Umount()) } if c.smbSession != nil { - errLogoff = c.smbSession.Logoff() + errs = append(errs, c.smbSession.Logoff()) } - return errors.Join(err, errLogoff) + return errors.Join(errs...) } -// True if it's closed -func (c *conn) closed() bool { - var nopErr error +// isClosed checks if the connection is closed. +func (c *conn) isClosed() bool { if c.smbShare != nil { // stat the current directory - _, nopErr = c.smbShare.Stat(".") - } else { - // list the shares - _, nopErr = c.smbSession.ListSharenames() + _, err := c.smbShare.Stat(".") + return err != nil } - return nopErr == nil + // list the shares + _, err := c.smbSession.ListSharenames() + return err != nil } -// Show that we are using a SMB session -// -// Call removeSession() when done +// addSession increments the active session count when an SMB session needs to be used. +// If this is called, we must call removeSession when we are done using the session. func (b *SMB) addSession() { atomic.AddInt32(&b.sessions, 1) } -// Show the SMB session is no longer in use +// removeSession decrements the active session count when it is no longer in use. func (b *SMB) removeSession() { atomic.AddInt32(&b.sessions, -1) } -// getSessions shows whether there are any sessions in use -func (b *SMB) getSessions() int32 { +// getSessionCount returns the number of active sessions. +func (b *SMB) getSessionCount() int32 { return atomic.LoadInt32(&b.sessions) } @@ -87,15 +85,13 @@ func (b *SMB) getSessions() int32 { // initiates the SMB handshake, and then returns a session for SMB communication. func (b *SMB) dial(ctx context.Context, network, addr string) (*conn, error) { dialer := net.Dialer{} - tconn, err := dialer.Dial(network, addr) + netConn, err := dialer.DialContext(ctx, network, addr) if err != nil { - return nil, err - } - var clientID [16]byte - if b.ClientGUID != "" { - copy(clientID[:], []byte(b.ClientGUID)) + return nil, fmt.Errorf("SMB dial failed: %w", err) } + clientID := b.getClientID() + d := &smb2.Dialer{ Negotiator: smb2.Negotiator{ RequireMessageSigning: b.RequireMessageSigning, @@ -109,26 +105,35 @@ func (b *SMB) dial(ctx context.Context, network, addr string) (*conn, error) { }, } - session, err := d.DialContext(ctx, tconn) + session, err := d.DialContext(ctx, netConn) if err != nil { - return nil, err + return nil, fmt.Errorf("SMB session initialization failed: %w", err) } return &conn{ + netConn: netConn, smbSession: session, - conn: &tconn, }, nil } -// Open a new connection to the SMB server. -func (b *SMB) newConnection(share string) (c *conn, err error) { +// getClientID returns the client GUID. +func (b *SMB) getClientID() [16]byte { + var clientID [16]byte + if b.ClientGUID != "" { + copy(clientID[:], []byte(b.ClientGUID)) + } + return clientID +} + +// newConnection creates a new SMB connection. +func (b *SMB) newConnection(share string) (*conn, error) { // As we are pooling these connections we need to decouple // them from the current context ctx := context.Background() - c, err = b.dial(ctx, "tcp", b.Host+":"+strconv.Itoa(b.Port)) + c, err := b.dial(ctx, "tcp", net.JoinHostPort(b.Host, strconv.Itoa(b.Port))) if err != nil { - return nil, fmt.Errorf("couldn't connect SMB: %w", err) + return nil, fmt.Errorf("SMB connection failed: %w", err) } if share != "" { @@ -136,7 +141,7 @@ func (b *SMB) newConnection(share string) (c *conn, err error) { c.smbShare, err = c.smbSession.Mount(share) if err != nil { _ = c.smbSession.Logoff() - return nil, fmt.Errorf("couldn't initialize SMB: %w", err) + return nil, fmt.Errorf("SMB share mount failed: %w", err) } c.smbShare = c.smbShare.WithContext(ctx) } @@ -144,29 +149,31 @@ func (b *SMB) newConnection(share string) (c *conn, err error) { return c, nil } -// Ensure the specified share is mounted or the session is unmounted -func (c *conn) mountShare(share string) (err error) { +// mountShare ensures the existing share is unmounted and the specified share is mounted. +func (c *conn) mountShare(share string) error { if c.shareName == share { return nil } if c.smbShare != nil { - err = c.smbShare.Umount() + if err := c.smbShare.Umount(); err != nil { + // Check if we should not nil out the share for some errors + c.smbShare = nil + return err + } c.smbShare = nil } - if err != nil { - return - } if share != "" { + var err error c.smbShare, err = c.smbSession.Mount(share) if err != nil { - return + return err } } c.shareName = share return nil } -// Get a SMB connection from the pool, or open a new one +// getConnection retrieves a connection from the pool or creates a new one. func (b *SMB) getConnection(share string) (c *conn, err error) { b.poolMu.Lock() for len(b.pool) > 0 { @@ -183,60 +190,54 @@ func (b *SMB) getConnection(share string) (c *conn, err error) { if c != nil { return c, nil } - c, err = b.newConnection(share) - return c, err + return b.newConnection(share) } -// put the connection back into the connection pool for reuse +// putConnection returns a connection to the pool for reuse. func (b *SMB) putConnection(c *conn) { if c == nil { return } - var nopErr error - if c.smbShare != nil { - // stat the current directory - _, nopErr = c.smbShare.Stat(".") - } else { - // list the shares - _, nopErr = c.smbSession.ListSharenames() - } - if nopErr != nil { - debug.Log("Connection failed, closing: %v", nopErr) + if c.isClosed() { + debug.Log("Connection closed, not returning to pool") _ = c.close() return } b.poolMu.Lock() + defer b.poolMu.Unlock() + b.pool = append(b.pool, c) b.drain.Reset(b.IdleTimeout) - b.poolMu.Unlock() } -// Drain the pool of any connections -func (b *SMB) drainPool() (err error) { +// drainPool closes all unused connections in the pool. +func (b *SMB) drainPool() error { b.poolMu.Lock() defer b.poolMu.Unlock() - if sessions := b.getSessions(); sessions != 0 { + + if sessions := b.getSessionCount(); sessions != 0 { debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions) - b.drain.Reset(b.IdleTimeout) // nudge on the pool emptying timer + b.drain.Reset(b.IdleTimeout) // reset the timer to keep the pool open return nil } + if b.IdleTimeout > 0 { b.drain.Stop() } if len(b.pool) != 0 { - debug.Log("Closing %d unused connections", len(b.pool)) + debug.Log("Attempting to close %d unused connections", len(b.pool)) } - for i, c := range b.pool { - if !c.closed() { - cErr := c.close() - if cErr != nil { - err = cErr + var errs []error + for _, c := range b.pool { + if !c.isClosed() { + if err := c.close(); err != nil { + errs = append(errs, err) } } - b.pool[i] = nil } b.pool = nil - return err + + return errors.Join(errs...) } diff --git a/internal/backend/smb/smb.go b/internal/backend/smb/smb.go index 517941dfb..299bcd1a3 100644 --- a/internal/backend/smb/smb.go +++ b/internal/backend/smb/smb.go @@ -18,6 +18,7 @@ import ( "github.com/restic/restic/internal/backend/location" "github.com/restic/restic/internal/backend/util" "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/errors" ) // Parts of this code have been adapted from Rclone (https://github.com/rclone) @@ -53,13 +54,15 @@ type SMB struct { util.Modes } -// ensure statically that *SMB implements backend.Backend. +// Ensure statically that *SMB implements backend.Backend interface. var _ backend.Backend = &SMB{} +// NewFactory returns a new SMB backend factory. func NewFactory() location.Factory { return location.NewLimitedBackendFactory("smb", ParseConfig, location.NoPassword, limiter.WrapBackendConstructor(Create), limiter.WrapBackendConstructor(Open)) } +// open initializes a new SMB backend. func open(cfg Config) (*SMB, error) { l := layout.NewDefaultLayout(cfg.Path, filepath.Join) @@ -111,36 +114,37 @@ func Create(_ context.Context, cfg Config) (*SMB, error) { return b, nil } +// Connections returns the number of configured connections. func (b *SMB) Connections() uint { return b.Config.Connections } -// Hasher may return a hash function for calculating a content hash for the backend +// Hasher returns a hash function for calculating a content hash for the backend. func (b *SMB) Hasher() hash.Hash { return nil } -// HasAtomicReplace returns whether Save() can atomically replace files +// HasAtomicReplace returns whether Save() can atomically replace files. func (b *SMB) HasAtomicReplace() bool { return true } -// IsNotExist returns true if the error is caused by a non existing file. +// IsNotExist returns true if the error is caused by a non-existing file. func (b *SMB) IsNotExist(err error) bool { return util.IsNotExist(err) } +// IsPermanentError returns true if the error is permanent. func (b *SMB) IsPermanentError(err error) bool { return util.IsPermanentError(err) } // Save stores data in the backend at the handle. -func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) (err error) { +func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) error { b.addSession() // Show session in use defer b.removeSession() - var cn *conn - cn, err = b.getConnection(b.ShareName) + cn, err := b.getConnection(b.ShareName) if err != nil { return err } @@ -157,31 +161,24 @@ func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) MkDir: func(dir string) error { return cn.smbShare.MkdirAll(dir, b.Modes.Dir) }, - Remove: cn.smbShare.Remove, - IsMacENOTTY: func(error) bool { - return false - }, - Rename: cn.smbShare.Rename, - FsyncDir: func(_ string) error { - return nil - }, - SetFileReadonly: func(name string) error { - return cn.setFileReadonly(name, b.Modes.File) - }, - DirMode: b.Modes.Dir, - FileMode: b.Modes.File, + Remove: cn.smbShare.Remove, + IsMacENOTTY: func(error) bool { return false }, + Rename: cn.smbShare.Rename, + FsyncDir: func(_ string) error { return nil }, + SetFileReadonly: func(f string) error { return cn.setFileReadonly(f, b.Modes.File) }, + DirMode: b.Modes.Dir, + FileMode: b.Modes.File, } return util.SaveWithOptions(fileName, tmpFilename, rd, saveOptions) } -// set file to readonly +// setFileReadonly sets the file to read-only mode. func (cn *conn) setFileReadonly(f string, mode os.FileMode) error { return cn.smbShare.Chmod(f, mode&^0222) } -// Load runs fn with a reader that yields the contents of the file at h at the -// given offset. +// Load runs fn with a reader that yields the contents of the file at h at the given offset. func (b *SMB) Load(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { return util.DefaultLoad(ctx, h, length, offset, b.openReader, fn) } @@ -247,17 +244,15 @@ func (b *SMB) Delete(_ context.Context) error { // Close closes all open files. func (b *SMB) Close() error { - err := b.drainPool() - return err + return b.drainPool() } // tempSuffix generates a random string suffix that should be sufficiently long // to avoid accidental conflicts. func tempSuffix() string { var nonce [16]byte - _, err := rand.Read(nonce[:]) - if err != nil { - panic(err) + if _, err := rand.Read(nonce[:]); err != nil { + panic(errors.Wrap(err, "failed to generate random suffix")) } return hex.EncodeToString(nonce[:]) } diff --git a/internal/backend/util/file_helper.go b/internal/backend/util/file_helper.go index 60c72eca4..dd1b6e586 100644 --- a/internal/backend/util/file_helper.go +++ b/internal/backend/util/file_helper.go @@ -16,44 +16,15 @@ import ( "github.com/restic/restic/internal/fs" ) -// File is common interface for os.File and smb.File +// File is a common interface for os.File and smb.File type File interface { - Close() error + io.ReadWriteCloser + io.Seeker Name() string - Read(p []byte) (n int, err error) Readdir(count int) ([]os.FileInfo, error) Readdirnames(n int) ([]string, error) - Seek(offset int64, whence int) (int64, error) Stat() (os.FileInfo, error) Sync() error - Write(p []byte) (n int, err error) -} - -var errTooShort = fmt.Errorf("file is too short") - -func DeriveModesFromStat(l layout.Layout, statFn func(string) (os.FileInfo, error)) Modes { - fi, err := statFn(l.Filename(backend.Handle{Type: backend.ConfigFile})) - m := DeriveModesFromFileInfo(fi, err) - debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir) - return m -} - -// Create creates all the necessary files and directories for a new local -// backend at dir. Afterwards a new config blob should be created. -func Create(fileName string, dirMode os.FileMode, paths []string, lstatFn func(string) (os.FileInfo, error), MkdirAllFn func(string, os.FileMode) error) error { - // test if config file already exists - _, err := lstatFn(fileName) - if err == nil { - return errors.New("config file already exists") - } - // create paths for data and refs - for _, d := range paths { - err := MkdirAllFn(d, dirMode) - if err != nil { - return errors.WithStack(err) - } - } - return nil } // SaveOptions contains options for saving files. @@ -69,6 +40,32 @@ type SaveOptions struct { FileMode os.FileMode } +var errTooShort = fmt.Errorf("file is too short") + +// DeriveModesFromStat derives file modes from the given layout and stat function. +func DeriveModesFromStat(l layout.Layout, statFn func(string) (os.FileInfo, error)) Modes { + fi, err := statFn(l.Filename(backend.Handle{Type: backend.ConfigFile})) + m := DeriveModesFromFileInfo(fi, err) + debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir) + return m +} + +// Create creates all the necessary files and directories for a new local backend +// at dir. Afterwards a new config blob should be created. +func Create(fileName string, dirMode os.FileMode, paths []string, lstatFn func(string) (os.FileInfo, error), mkdirAllFn func(string, os.FileMode) error) error { + // test if config file already exists + if _, err := lstatFn(fileName); err == nil { + return errors.New("config file already exists") + } + // create paths for data and refs + for _, d := range paths { + if err := mkdirAllFn(d, dirMode); err != nil { + return errors.WithStack(err) + } + } + return nil +} + // SaveWithOptions stores data in the backend at the handle using the provided options. func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReader, options SaveOptions) (err error) { dir := filepath.Dir(fileName) @@ -81,14 +78,11 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade }() f, err := options.OpenTempFile(dir, tmpFilename) - if IsNotExist(err) { debug.Log("error %v: creating dir", err) - // error is caused by a missing directory, try to create it - mkdirErr := options.MkDir(dir) - if mkdirErr != nil { - debug.Log("error creating dir %v: %v", dir, mkdirErr) + if err := options.MkDir(dir); err != nil { + debug.Log("error creating dir %v: %v", dir, err) } else { // try again f, err = options.OpenTempFile(dir, tmpFilename) @@ -99,7 +93,7 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade return errors.WithStack(err) } - defer func(f File) { + defer func() { if err != nil { _ = f.Close() // Double Close is harmless. // Remove after Rename is harmless: we embed the final name in the @@ -108,24 +102,21 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade // goroutine. _ = options.Remove(f.Name()) } - }(f) + }() - if f, ok := f.(*os.File); ok { + if osFile, ok := f.(*os.File); ok { // preallocate disk space only for os.File if size := rd.Length(); size > 0 { - if err := fs.PreallocateFile(f, size); err != nil { + if err := fs.PreallocateFile(osFile, size); err != nil { debug.Log("Failed to preallocate %v with size %v: %v", fileName, size, err) } } } // save data, then sync - wbytes, err := io.Copy(f, rd) - if err != nil { + if wbytes, err := io.Copy(f, rd); err != nil { return errors.WithStack(err) - } - // sanity check - if wbytes != rd.Length() { + } else if wbytes != rd.Length() { // sanity check return errors.Errorf("wrote %d bytes instead of the expected %d bytes", wbytes, rd.Length()) } @@ -137,17 +128,17 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade } // Close, then rename. Windows doesn't like the reverse order. - if err = f.Close(); err != nil { + if err := f.Close(); err != nil { return errors.WithStack(err) } - if err = options.Rename(f.Name(), fileName); err != nil { + + if err := options.Rename(f.Name(), fileName); err != nil { return errors.WithStack(err) } // Now sync the directory to commit the Rename. if !syncNotSup { - err = options.FsyncDir(dir) - if err != nil { + if err := options.FsyncDir(dir); err != nil { return errors.WithStack(err) } } @@ -155,34 +146,33 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade // try to mark file as read-only to avoid accidental modifications // ignore if the operation fails as some filesystems don't allow the chmod call // e.g. exfat and network file systems with certain mount options - err = options.SetFileReadonly(fileName) - if err != nil && !os.IsPermission(err) { + if err := options.SetFileReadonly(fileName); err != nil && !os.IsPermission(err) { return errors.WithStack(err) } return nil } +// OpenReader opens a file for reading with the given parameters. func OpenReader(openFile func(string) (File, error), fileName string, length int, offset int64) (io.ReadCloser, error) { f, err := openFile(fileName) if err != nil { return nil, err } + fi, err := f.Stat() if err != nil { _ = f.Close() return nil, err } - size := fi.Size() - if size < offset+int64(length) { + if fi.Size() < offset+int64(length) { _ = f.Close() return nil, errTooShort } if offset > 0 { - _, err = f.Seek(offset, 0) - if err != nil { + if _, err = f.Seek(offset, 0); err != nil { _ = f.Close() return nil, err } @@ -205,13 +195,11 @@ func Stat(statFn func(string) (os.FileInfo, error), fileName, handleName string) } // Remove removes the blob with the given name and type. -func Remove(filename string, chmodfn func(string, os.FileMode) error, removeFn func(string) error) error { +func Remove(filename string, chmodFn func(string, os.FileMode) error, removeFn func(string) error) error { // reset read-only flag - err := chmodfn(filename, 0666) - if err != nil && !os.IsPermission(err) { + if err := chmodFn(filename, 0666); err != nil && !os.IsPermission(err) { return errors.WithStack(err) } - return removeFn(filename) } @@ -250,14 +238,12 @@ func visitDirs(ctx context.Context, openDir func(string) (File, error), dir stri return err } - err = d.Close() - if err != nil { + if err := d.Close(); err != nil { return err } for _, f := range sub { - err = visitFiles(ctx, openDir, filepath.Join(dir, f), fn, true) - if err != nil { + if err := visitFiles(ctx, openDir, filepath.Join(dir, f), fn, true); err != nil { return err } } @@ -287,8 +273,7 @@ func visitFiles(ctx context.Context, openDir func(string) (File, error), dir str return err } - err = d.Close() - if err != nil { + if err := d.Close(); err != nil { return err } @@ -299,11 +284,7 @@ func visitFiles(ctx context.Context, openDir func(string) (File, error), dir str default: } - err := fn(backend.FileInfo{ - Name: fi.Name(), - Size: fi.Size(), - }) - if err != nil { + if err := fn(backend.FileInfo{Name: fi.Name(), Size: fi.Size()}); err != nil { return err } }