1
0
Fork 0
mirror of https://github.com/restic/restic.git synced 2025-03-09 00:00:02 +01:00

Clean up code and fix documentation

This commit is contained in:
aneesh-n 2024-09-02 14:15:18 -06:00
parent b6123d3080
commit 3e7a369a4e
No known key found for this signature in database
GPG key ID: 6F5A52831C046F44
7 changed files with 190 additions and 234 deletions

View file

@ -639,7 +639,7 @@ You must first setup the following environment variables with the SMB credential
$ export RESTIC_SMB_DOMAIN=<MY_SMB_DOMAIN>
Once the server is configured, the setup of the SFTP repository can
Once the server is configured, the setup of the SMB repository can
simply be achieved by changing the URL scheme in the ``init`` command:
.. code-block:: console

View file

@ -15,16 +15,17 @@ import (
"github.com/restic/restic/internal/debug"
)
// Local is a backend in a local directory.
// Local represents a backend in a local directory.
type Local struct {
Config
layout.Layout
util.Modes
}
// ensure statically that *Local implements backend.Backend.
// Ensure statically that *Local implements backend.Backend.
var _ backend.Backend = &Local{}
// NewFactory returns a new factory for local backends.
func NewFactory() location.Factory {
return location.NewLimitedBackendFactory("local", ParseConfig, location.NoPassword, limiter.WrapBackendConstructor(Create), limiter.WrapBackendConstructor(Open))
}
@ -60,31 +61,33 @@ func Create(_ context.Context, cfg Config) (*Local, error) {
return be, nil
}
// Connections returns the number of configured connections.
func (b *Local) Connections() uint {
return b.Config.Connections
}
// Hasher may return a hash function for calculating a content hash for the backend
// Hasher returns a hash function for calculating a content hash for the backend.
func (b *Local) Hasher() hash.Hash {
return nil
}
// HasAtomicReplace returns whether Save() can atomically replace files
// HasAtomicReplace returns whether Save() can atomically replace files.
func (b *Local) HasAtomicReplace() bool {
return true
}
// IsNotExist returns true if the error is caused by a non existing file.
// IsNotExist returns true if the error is caused by a non-existing file.
func (b *Local) IsNotExist(err error) bool {
return util.IsNotExist(err)
}
// IsPermanentError returns true if the error is permanent.
func (b *Local) IsPermanentError(err error) bool {
return util.IsPermanentError(err)
}
// Save stores data in the backend at the handle.
func (b *Local) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) (err error) {
func (b *Local) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) error {
fileName := b.Filename(h)
// Create new file with a temporary name.
tmpFilename := filepath.Base(fileName) + "-tmp-"
@ -137,7 +140,7 @@ func (b *Local) Remove(_ context.Context, h backend.Handle) error {
// List runs fn for each file in the backend which has the type t. When an
// error occurs (or fn returns an error), List stops and returns it.
func (b *Local) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) (err error) {
func (b *Local) List(ctx context.Context, t backend.FileType, fn func(backend.FileInfo) error) error {
openFunc := func(name string) (util.File, error) {
return os.Open(name)
}
@ -152,7 +155,7 @@ func (b *Local) Delete(_ context.Context) error {
// Close closes all open files.
func (b *Local) Close() error {
// this does not need to do anything, all open files are closed within the
// This does not need to do anything, all open files are closed within the
// same function.
return nil
}

View file

@ -12,6 +12,13 @@ import (
"github.com/restic/restic/internal/options"
)
const (
DefaultSMBPort int = 445 // DefaultSMBt returns the default port for SMB
DefaultDomain string = "WORKGROUP" // DefaultDomain returns the default domain for SMB
DefaultConnections uint = 5 // DefaultConnections returns the number of concurrent connections for SMB.
DefaultIdleTimeout time.Duration = 60 * time.Second // DefaultIdleTimeout returns the default max time before closing idle connections for SMB.
)
// Config contains all configuration necessary to connect to an SMB server
type Config struct {
Host string
@ -31,17 +38,10 @@ type Config struct {
ClientGUID string `option:"client-guid" help:"A 16-byte GUID to uniquely identify a client. If not specific a random GUID is used. (default: \"\")"`
}
const (
DefaultSmbPort int = 445 // DefaultSmbPort returns the default port for SMB
DefaultDomain string = "WORKGROUP" // DefaultDomain returns the default domain for SMB
DefaultConnections uint = 5 // DefaultConnections returns the number of concurrent connections for SMB.
DefaultIdleTimeout time.Duration = 60 * time.Second // DefaultIdleTimeout returns the default max time before closing idle connections for SMB.
)
// NewConfig returns a new Config with the default values filled in.
func NewConfig() Config {
return Config{
Port: DefaultSmbPort,
Port: DefaultSMBPort,
Domain: DefaultDomain,
IdleTimeout: DefaultIdleTimeout,
Connections: DefaultConnections,
@ -52,76 +52,52 @@ func init() {
options.Register("smb", Config{})
}
// ParseConfig parses the string s and extracts the s3 config. The
// ParseConfig parses the string s and extracts the SMB config. The
// supported configuration format is smb://[user@]host[:port]/sharename/directory.
// User and port are optional. Default port is 445.
func ParseConfig(s string) (*Config, error) {
var repo string
switch {
case strings.HasPrefix(s, "smb://"):
repo = s
case strings.HasPrefix(s, "smb:"):
repo = "smb://" + s[4:]
default:
hasSmbColonPrefix := strings.HasPrefix(s, "smb:")
if !hasSmbColonPrefix && !strings.HasPrefix(s, "smb://") {
return nil, errors.New("smb: invalid format")
}
var user, host, port, dir string
if hasSmbColonPrefix {
s = "smb://" + s[4:]
}
// parse the "smb://user@host/sharename/directory." url format
url, err := url.Parse(repo)
u, err := url.Parse(s)
if err != nil {
return nil, errors.WithStack(err)
}
if url.User != nil {
user = url.User.Username()
cfg := NewConfig()
if u.User != nil {
cfg.User = u.User.Username()
//Intentionally not allowing passwords to be set in url as
//it can cause issues when passwords have special characters
//like '@' and it is not recommended to pass passwords in the url.
}
host = url.Hostname()
if host == "" {
cfg.Host = u.Hostname()
if cfg.Host == "" {
return nil, errors.New("smb: invalid format, host name not found")
}
port = url.Port()
dir = url.Path
if dir == "" {
return nil, errors.Errorf("smb: invalid format, sharename/directory not found")
}
dir = dir[1:]
var portNum int
if port == "" {
portNum = DefaultSmbPort
} else {
var err error
portNum, err = strconv.Atoi(port)
if u.Port() != "" {
cfg.Port, err = strconv.Atoi(u.Port())
if err != nil {
return nil, err
return nil, errors.Wrap(err, "invalid port number")
}
}
sharename, directory, _ := strings.Cut(dir, "/")
return createConfig(user, host, portNum, sharename, directory)
}
func createConfig(user string, host string, port int, sharename, directory string) (*Config, error) {
if host == "" {
return nil, errors.New("smb: invalid format, Host not found")
if u.Path == "" {
return nil, errors.New("smb: invalid format, sharename/directory not found")
}
if directory != "" {
directory = path.Clean(directory)
}
cfg.ShareName, cfg.Path, _ = strings.Cut(u.Path[1:], "/")
cfg.Path = path.Clean(cfg.Path)
cfg := NewConfig()
cfg.User = user
cfg.Host = host
cfg.Port = port
cfg.ShareName = sharename
cfg.Path = directory
return &cfg, nil
}
@ -135,9 +111,9 @@ func (cfg *Config) ApplyEnvironment(prefix string) error {
}
if cfg.Domain == "" {
cfg.Domain = os.Getenv(prefix + "RESTIC_SMB_DOMAIN")
}
if cfg.Domain == "" {
cfg.Domain = DefaultDomain
if cfg.Domain == "" {
cfg.Domain = DefaultDomain
}
}
return nil
}

View file

@ -11,7 +11,7 @@ var configTests = []test.ConfigTestData[Config]{
{S: "smb://user@host/sharename/directory",
Cfg: Config{
Host: "host",
Port: DefaultSmbPort,
Port: DefaultSMBPort,
User: "user",
Domain: DefaultDomain,
ShareName: "sharename",
@ -33,7 +33,7 @@ var configTests = []test.ConfigTestData[Config]{
{S: "smb://host/sharename/directory",
Cfg: Config{
Host: "host",
Port: DefaultSmbPort,
Port: DefaultSMBPort,
Domain: DefaultDomain,
ShareName: "sharename",
Path: "directory",

View file

@ -33,52 +33,50 @@ import (
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
// conn encapsulates a SMB client and corresponding SMB client
// conn encapsulates a SMB client and corresponding SMB session and share.
type conn struct {
conn *net.Conn
netConn net.Conn
smbSession *smb2.Session
smbShare *smb2.Share
shareName string
}
func (c *conn) close() error {
var err, errLogoff error
var errs []error
if c.smbShare != nil {
err = c.smbShare.Umount()
errs = append(errs, c.smbShare.Umount())
}
if c.smbSession != nil {
errLogoff = c.smbSession.Logoff()
errs = append(errs, c.smbSession.Logoff())
}
return errors.Join(err, errLogoff)
return errors.Join(errs...)
}
// True if it's closed
func (c *conn) closed() bool {
var nopErr error
// isClosed checks if the connection is closed.
func (c *conn) isClosed() bool {
if c.smbShare != nil {
// stat the current directory
_, nopErr = c.smbShare.Stat(".")
} else {
// list the shares
_, nopErr = c.smbSession.ListSharenames()
_, err := c.smbShare.Stat(".")
return err != nil
}
return nopErr == nil
// list the shares
_, err := c.smbSession.ListSharenames()
return err != nil
}
// Show that we are using a SMB session
//
// Call removeSession() when done
// addSession increments the active session count when an SMB session needs to be used.
// If this is called, we must call removeSession when we are done using the session.
func (b *SMB) addSession() {
atomic.AddInt32(&b.sessions, 1)
}
// Show the SMB session is no longer in use
// removeSession decrements the active session count when it is no longer in use.
func (b *SMB) removeSession() {
atomic.AddInt32(&b.sessions, -1)
}
// getSessions shows whether there are any sessions in use
func (b *SMB) getSessions() int32 {
// getSessionCount returns the number of active sessions.
func (b *SMB) getSessionCount() int32 {
return atomic.LoadInt32(&b.sessions)
}
@ -87,15 +85,13 @@ func (b *SMB) getSessions() int32 {
// initiates the SMB handshake, and then returns a session for SMB communication.
func (b *SMB) dial(ctx context.Context, network, addr string) (*conn, error) {
dialer := net.Dialer{}
tconn, err := dialer.Dial(network, addr)
netConn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
var clientID [16]byte
if b.ClientGUID != "" {
copy(clientID[:], []byte(b.ClientGUID))
return nil, fmt.Errorf("SMB dial failed: %w", err)
}
clientID := b.getClientID()
d := &smb2.Dialer{
Negotiator: smb2.Negotiator{
RequireMessageSigning: b.RequireMessageSigning,
@ -109,26 +105,35 @@ func (b *SMB) dial(ctx context.Context, network, addr string) (*conn, error) {
},
}
session, err := d.DialContext(ctx, tconn)
session, err := d.DialContext(ctx, netConn)
if err != nil {
return nil, err
return nil, fmt.Errorf("SMB session initialization failed: %w", err)
}
return &conn{
netConn: netConn,
smbSession: session,
conn: &tconn,
}, nil
}
// Open a new connection to the SMB server.
func (b *SMB) newConnection(share string) (c *conn, err error) {
// getClientID returns the client GUID.
func (b *SMB) getClientID() [16]byte {
var clientID [16]byte
if b.ClientGUID != "" {
copy(clientID[:], []byte(b.ClientGUID))
}
return clientID
}
// newConnection creates a new SMB connection.
func (b *SMB) newConnection(share string) (*conn, error) {
// As we are pooling these connections we need to decouple
// them from the current context
ctx := context.Background()
c, err = b.dial(ctx, "tcp", b.Host+":"+strconv.Itoa(b.Port))
c, err := b.dial(ctx, "tcp", net.JoinHostPort(b.Host, strconv.Itoa(b.Port)))
if err != nil {
return nil, fmt.Errorf("couldn't connect SMB: %w", err)
return nil, fmt.Errorf("SMB connection failed: %w", err)
}
if share != "" {
@ -136,7 +141,7 @@ func (b *SMB) newConnection(share string) (c *conn, err error) {
c.smbShare, err = c.smbSession.Mount(share)
if err != nil {
_ = c.smbSession.Logoff()
return nil, fmt.Errorf("couldn't initialize SMB: %w", err)
return nil, fmt.Errorf("SMB share mount failed: %w", err)
}
c.smbShare = c.smbShare.WithContext(ctx)
}
@ -144,29 +149,31 @@ func (b *SMB) newConnection(share string) (c *conn, err error) {
return c, nil
}
// Ensure the specified share is mounted or the session is unmounted
func (c *conn) mountShare(share string) (err error) {
// mountShare ensures the existing share is unmounted and the specified share is mounted.
func (c *conn) mountShare(share string) error {
if c.shareName == share {
return nil
}
if c.smbShare != nil {
err = c.smbShare.Umount()
if err := c.smbShare.Umount(); err != nil {
// Check if we should not nil out the share for some errors
c.smbShare = nil
return err
}
c.smbShare = nil
}
if err != nil {
return
}
if share != "" {
var err error
c.smbShare, err = c.smbSession.Mount(share)
if err != nil {
return
return err
}
}
c.shareName = share
return nil
}
// Get a SMB connection from the pool, or open a new one
// getConnection retrieves a connection from the pool or creates a new one.
func (b *SMB) getConnection(share string) (c *conn, err error) {
b.poolMu.Lock()
for len(b.pool) > 0 {
@ -183,60 +190,54 @@ func (b *SMB) getConnection(share string) (c *conn, err error) {
if c != nil {
return c, nil
}
c, err = b.newConnection(share)
return c, err
return b.newConnection(share)
}
// put the connection back into the connection pool for reuse
// putConnection returns a connection to the pool for reuse.
func (b *SMB) putConnection(c *conn) {
if c == nil {
return
}
var nopErr error
if c.smbShare != nil {
// stat the current directory
_, nopErr = c.smbShare.Stat(".")
} else {
// list the shares
_, nopErr = c.smbSession.ListSharenames()
}
if nopErr != nil {
debug.Log("Connection failed, closing: %v", nopErr)
if c.isClosed() {
debug.Log("Connection closed, not returning to pool")
_ = c.close()
return
}
b.poolMu.Lock()
defer b.poolMu.Unlock()
b.pool = append(b.pool, c)
b.drain.Reset(b.IdleTimeout)
b.poolMu.Unlock()
}
// Drain the pool of any connections
func (b *SMB) drainPool() (err error) {
// drainPool closes all unused connections in the pool.
func (b *SMB) drainPool() error {
b.poolMu.Lock()
defer b.poolMu.Unlock()
if sessions := b.getSessions(); sessions != 0 {
if sessions := b.getSessionCount(); sessions != 0 {
debug.Log("Not closing %d unused connections as %d sessions active", len(b.pool), sessions)
b.drain.Reset(b.IdleTimeout) // nudge on the pool emptying timer
b.drain.Reset(b.IdleTimeout) // reset the timer to keep the pool open
return nil
}
if b.IdleTimeout > 0 {
b.drain.Stop()
}
if len(b.pool) != 0 {
debug.Log("Closing %d unused connections", len(b.pool))
debug.Log("Attempting to close %d unused connections", len(b.pool))
}
for i, c := range b.pool {
if !c.closed() {
cErr := c.close()
if cErr != nil {
err = cErr
var errs []error
for _, c := range b.pool {
if !c.isClosed() {
if err := c.close(); err != nil {
errs = append(errs, err)
}
}
b.pool[i] = nil
}
b.pool = nil
return err
return errors.Join(errs...)
}

View file

@ -18,6 +18,7 @@ import (
"github.com/restic/restic/internal/backend/location"
"github.com/restic/restic/internal/backend/util"
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/errors"
)
// Parts of this code have been adapted from Rclone (https://github.com/rclone)
@ -53,13 +54,15 @@ type SMB struct {
util.Modes
}
// ensure statically that *SMB implements backend.Backend.
// Ensure statically that *SMB implements backend.Backend interface.
var _ backend.Backend = &SMB{}
// NewFactory returns a new SMB backend factory.
func NewFactory() location.Factory {
return location.NewLimitedBackendFactory("smb", ParseConfig, location.NoPassword, limiter.WrapBackendConstructor(Create), limiter.WrapBackendConstructor(Open))
}
// open initializes a new SMB backend.
func open(cfg Config) (*SMB, error) {
l := layout.NewDefaultLayout(cfg.Path, filepath.Join)
@ -111,36 +114,37 @@ func Create(_ context.Context, cfg Config) (*SMB, error) {
return b, nil
}
// Connections returns the number of configured connections.
func (b *SMB) Connections() uint {
return b.Config.Connections
}
// Hasher may return a hash function for calculating a content hash for the backend
// Hasher returns a hash function for calculating a content hash for the backend.
func (b *SMB) Hasher() hash.Hash {
return nil
}
// HasAtomicReplace returns whether Save() can atomically replace files
// HasAtomicReplace returns whether Save() can atomically replace files.
func (b *SMB) HasAtomicReplace() bool {
return true
}
// IsNotExist returns true if the error is caused by a non existing file.
// IsNotExist returns true if the error is caused by a non-existing file.
func (b *SMB) IsNotExist(err error) bool {
return util.IsNotExist(err)
}
// IsPermanentError returns true if the error is permanent.
func (b *SMB) IsPermanentError(err error) bool {
return util.IsPermanentError(err)
}
// Save stores data in the backend at the handle.
func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) (err error) {
func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader) error {
b.addSession() // Show session in use
defer b.removeSession()
var cn *conn
cn, err = b.getConnection(b.ShareName)
cn, err := b.getConnection(b.ShareName)
if err != nil {
return err
}
@ -157,31 +161,24 @@ func (b *SMB) Save(_ context.Context, h backend.Handle, rd backend.RewindReader)
MkDir: func(dir string) error {
return cn.smbShare.MkdirAll(dir, b.Modes.Dir)
},
Remove: cn.smbShare.Remove,
IsMacENOTTY: func(error) bool {
return false
},
Rename: cn.smbShare.Rename,
FsyncDir: func(_ string) error {
return nil
},
SetFileReadonly: func(name string) error {
return cn.setFileReadonly(name, b.Modes.File)
},
DirMode: b.Modes.Dir,
FileMode: b.Modes.File,
Remove: cn.smbShare.Remove,
IsMacENOTTY: func(error) bool { return false },
Rename: cn.smbShare.Rename,
FsyncDir: func(_ string) error { return nil },
SetFileReadonly: func(f string) error { return cn.setFileReadonly(f, b.Modes.File) },
DirMode: b.Modes.Dir,
FileMode: b.Modes.File,
}
return util.SaveWithOptions(fileName, tmpFilename, rd, saveOptions)
}
// set file to readonly
// setFileReadonly sets the file to read-only mode.
func (cn *conn) setFileReadonly(f string, mode os.FileMode) error {
return cn.smbShare.Chmod(f, mode&^0222)
}
// Load runs fn with a reader that yields the contents of the file at h at the
// given offset.
// Load runs fn with a reader that yields the contents of the file at h at the given offset.
func (b *SMB) Load(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
return util.DefaultLoad(ctx, h, length, offset, b.openReader, fn)
}
@ -247,17 +244,15 @@ func (b *SMB) Delete(_ context.Context) error {
// Close closes all open files.
func (b *SMB) Close() error {
err := b.drainPool()
return err
return b.drainPool()
}
// tempSuffix generates a random string suffix that should be sufficiently long
// to avoid accidental conflicts.
func tempSuffix() string {
var nonce [16]byte
_, err := rand.Read(nonce[:])
if err != nil {
panic(err)
if _, err := rand.Read(nonce[:]); err != nil {
panic(errors.Wrap(err, "failed to generate random suffix"))
}
return hex.EncodeToString(nonce[:])
}

View file

@ -16,44 +16,15 @@ import (
"github.com/restic/restic/internal/fs"
)
// File is common interface for os.File and smb.File
// File is a common interface for os.File and smb.File
type File interface {
Close() error
io.ReadWriteCloser
io.Seeker
Name() string
Read(p []byte) (n int, err error)
Readdir(count int) ([]os.FileInfo, error)
Readdirnames(n int) ([]string, error)
Seek(offset int64, whence int) (int64, error)
Stat() (os.FileInfo, error)
Sync() error
Write(p []byte) (n int, err error)
}
var errTooShort = fmt.Errorf("file is too short")
func DeriveModesFromStat(l layout.Layout, statFn func(string) (os.FileInfo, error)) Modes {
fi, err := statFn(l.Filename(backend.Handle{Type: backend.ConfigFile}))
m := DeriveModesFromFileInfo(fi, err)
debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir)
return m
}
// Create creates all the necessary files and directories for a new local
// backend at dir. Afterwards a new config blob should be created.
func Create(fileName string, dirMode os.FileMode, paths []string, lstatFn func(string) (os.FileInfo, error), MkdirAllFn func(string, os.FileMode) error) error {
// test if config file already exists
_, err := lstatFn(fileName)
if err == nil {
return errors.New("config file already exists")
}
// create paths for data and refs
for _, d := range paths {
err := MkdirAllFn(d, dirMode)
if err != nil {
return errors.WithStack(err)
}
}
return nil
}
// SaveOptions contains options for saving files.
@ -69,6 +40,32 @@ type SaveOptions struct {
FileMode os.FileMode
}
var errTooShort = fmt.Errorf("file is too short")
// DeriveModesFromStat derives file modes from the given layout and stat function.
func DeriveModesFromStat(l layout.Layout, statFn func(string) (os.FileInfo, error)) Modes {
fi, err := statFn(l.Filename(backend.Handle{Type: backend.ConfigFile}))
m := DeriveModesFromFileInfo(fi, err)
debug.Log("using (%03O file, %03O dir) permissions", m.File, m.Dir)
return m
}
// Create creates all the necessary files and directories for a new local backend
// at dir. Afterwards a new config blob should be created.
func Create(fileName string, dirMode os.FileMode, paths []string, lstatFn func(string) (os.FileInfo, error), mkdirAllFn func(string, os.FileMode) error) error {
// test if config file already exists
if _, err := lstatFn(fileName); err == nil {
return errors.New("config file already exists")
}
// create paths for data and refs
for _, d := range paths {
if err := mkdirAllFn(d, dirMode); err != nil {
return errors.WithStack(err)
}
}
return nil
}
// SaveWithOptions stores data in the backend at the handle using the provided options.
func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReader, options SaveOptions) (err error) {
dir := filepath.Dir(fileName)
@ -81,14 +78,11 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade
}()
f, err := options.OpenTempFile(dir, tmpFilename)
if IsNotExist(err) {
debug.Log("error %v: creating dir", err)
// error is caused by a missing directory, try to create it
mkdirErr := options.MkDir(dir)
if mkdirErr != nil {
debug.Log("error creating dir %v: %v", dir, mkdirErr)
if err := options.MkDir(dir); err != nil {
debug.Log("error creating dir %v: %v", dir, err)
} else {
// try again
f, err = options.OpenTempFile(dir, tmpFilename)
@ -99,7 +93,7 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade
return errors.WithStack(err)
}
defer func(f File) {
defer func() {
if err != nil {
_ = f.Close() // Double Close is harmless.
// Remove after Rename is harmless: we embed the final name in the
@ -108,24 +102,21 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade
// goroutine.
_ = options.Remove(f.Name())
}
}(f)
}()
if f, ok := f.(*os.File); ok {
if osFile, ok := f.(*os.File); ok {
// preallocate disk space only for os.File
if size := rd.Length(); size > 0 {
if err := fs.PreallocateFile(f, size); err != nil {
if err := fs.PreallocateFile(osFile, size); err != nil {
debug.Log("Failed to preallocate %v with size %v: %v", fileName, size, err)
}
}
}
// save data, then sync
wbytes, err := io.Copy(f, rd)
if err != nil {
if wbytes, err := io.Copy(f, rd); err != nil {
return errors.WithStack(err)
}
// sanity check
if wbytes != rd.Length() {
} else if wbytes != rd.Length() { // sanity check
return errors.Errorf("wrote %d bytes instead of the expected %d bytes", wbytes, rd.Length())
}
@ -137,17 +128,17 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade
}
// Close, then rename. Windows doesn't like the reverse order.
if err = f.Close(); err != nil {
if err := f.Close(); err != nil {
return errors.WithStack(err)
}
if err = options.Rename(f.Name(), fileName); err != nil {
if err := options.Rename(f.Name(), fileName); err != nil {
return errors.WithStack(err)
}
// Now sync the directory to commit the Rename.
if !syncNotSup {
err = options.FsyncDir(dir)
if err != nil {
if err := options.FsyncDir(dir); err != nil {
return errors.WithStack(err)
}
}
@ -155,34 +146,33 @@ func SaveWithOptions(fileName string, tmpFilename string, rd backend.RewindReade
// try to mark file as read-only to avoid accidental modifications
// ignore if the operation fails as some filesystems don't allow the chmod call
// e.g. exfat and network file systems with certain mount options
err = options.SetFileReadonly(fileName)
if err != nil && !os.IsPermission(err) {
if err := options.SetFileReadonly(fileName); err != nil && !os.IsPermission(err) {
return errors.WithStack(err)
}
return nil
}
// OpenReader opens a file for reading with the given parameters.
func OpenReader(openFile func(string) (File, error), fileName string, length int, offset int64) (io.ReadCloser, error) {
f, err := openFile(fileName)
if err != nil {
return nil, err
}
fi, err := f.Stat()
if err != nil {
_ = f.Close()
return nil, err
}
size := fi.Size()
if size < offset+int64(length) {
if fi.Size() < offset+int64(length) {
_ = f.Close()
return nil, errTooShort
}
if offset > 0 {
_, err = f.Seek(offset, 0)
if err != nil {
if _, err = f.Seek(offset, 0); err != nil {
_ = f.Close()
return nil, err
}
@ -205,13 +195,11 @@ func Stat(statFn func(string) (os.FileInfo, error), fileName, handleName string)
}
// Remove removes the blob with the given name and type.
func Remove(filename string, chmodfn func(string, os.FileMode) error, removeFn func(string) error) error {
func Remove(filename string, chmodFn func(string, os.FileMode) error, removeFn func(string) error) error {
// reset read-only flag
err := chmodfn(filename, 0666)
if err != nil && !os.IsPermission(err) {
if err := chmodFn(filename, 0666); err != nil && !os.IsPermission(err) {
return errors.WithStack(err)
}
return removeFn(filename)
}
@ -250,14 +238,12 @@ func visitDirs(ctx context.Context, openDir func(string) (File, error), dir stri
return err
}
err = d.Close()
if err != nil {
if err := d.Close(); err != nil {
return err
}
for _, f := range sub {
err = visitFiles(ctx, openDir, filepath.Join(dir, f), fn, true)
if err != nil {
if err := visitFiles(ctx, openDir, filepath.Join(dir, f), fn, true); err != nil {
return err
}
}
@ -287,8 +273,7 @@ func visitFiles(ctx context.Context, openDir func(string) (File, error), dir str
return err
}
err = d.Close()
if err != nil {
if err := d.Close(); err != nil {
return err
}
@ -299,11 +284,7 @@ func visitFiles(ctx context.Context, openDir func(string) (File, error), dir str
default:
}
err := fn(backend.FileInfo{
Name: fi.Name(),
Size: fi.Size(),
})
if err != nil {
if err := fn(backend.FileInfo{Name: fi.Name(), Size: fi.Size()}); err != nil {
return err
}
}