diff --git a/cmd/serve/httplib/httplib.go b/cmd/serve/httplib/httplib.go index 03302e044..42d5392f2 100644 --- a/cmd/serve/httplib/httplib.go +++ b/cmd/serve/httplib/httplib.go @@ -97,8 +97,15 @@ type Options struct { Realm string // realm for authentication BasicUser string // single username for basic auth if not using Htpasswd BasicPass string // password for BasicUser + Auth AuthFn // custom Auth (not set by command line flags) } +// AuthFn if used will be used to authenticate user, pass. If an error +// is returned then the user is not authenticated. +// +// If a non nil value is returned then it is added to the context under the key +type AuthFn func(user, pass string) (value interface{}, err error) + // DefaultOpt is the default values used for Options var DefaultOpt = Options{ ListenAddr: "localhost:8080", @@ -123,9 +130,14 @@ type Server struct { type contextUserType struct{} -// ContextUserKey is a simple context key +// ContextUserKey is a simple context key for storing the username of the request var ContextUserKey = &contextUserType{} +type contextAuthType struct{} + +// ContextAuthKey is a simple context key for storing info returned by AuthFn +var ContextAuthKey = &contextAuthType{} + // singleUserProvider provides the encrypted password for a single user func (s *Server) singleUserProvider(user, realm string) string { if user == s.Opt.BasicUser { @@ -134,6 +146,27 @@ func (s *Server) singleUserProvider(user, realm string) string { return "" } +// parseAuthorization parses the Authorization header into user, pass +// it returns a boolean as to whether the parse was successful +func parseAuthorization(r *http.Request) (user, pass string, ok bool) { + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + s := strings.SplitN(authHeader, " ", 2) + if len(s) == 2 && s[0] == "Basic" { + b, err := base64.StdEncoding.DecodeString(s[1]) + if err == nil { + parts := strings.SplitN(string(b), ":", 2) + user = parts[0] + if len(parts) > 1 { + pass = parts[1] + ok = true + } + } + } + } + return +} + // NewServer creates an http server. The opt can be nil in which case // the default options will be used. func NewServer(handler http.Handler, opt *Options) *Server { @@ -149,17 +182,20 @@ func NewServer(handler http.Handler, opt *Options) *Server { } // Use htpasswd if required on everything - if s.Opt.HtPasswd != "" || s.Opt.BasicUser != "" { - var secretProvider auth.SecretProvider - if s.Opt.HtPasswd != "" { - fs.Infof(nil, "Using %q as htpasswd storage", s.Opt.HtPasswd) - secretProvider = auth.HtpasswdFileProvider(s.Opt.HtPasswd) - } else { - fs.Infof(nil, "Using --user %s --pass XXXX as authenticated user", s.Opt.BasicUser) - s.basicPassHashed = string(auth.MD5Crypt([]byte(s.Opt.BasicPass), []byte("dlPL2MqE"), []byte("$1$"))) - secretProvider = s.singleUserProvider + if s.Opt.HtPasswd != "" || s.Opt.BasicUser != "" || s.Opt.Auth != nil { + var authenticator *auth.BasicAuth + if s.Opt.Auth == nil { + var secretProvider auth.SecretProvider + if s.Opt.HtPasswd != "" { + fs.Infof(nil, "Using %q as htpasswd storage", s.Opt.HtPasswd) + secretProvider = auth.HtpasswdFileProvider(s.Opt.HtPasswd) + } else { + fs.Infof(nil, "Using --user %s --pass XXXX as authenticated user", s.Opt.BasicUser) + s.basicPassHashed = string(auth.MD5Crypt([]byte(s.Opt.BasicPass), []byte("dlPL2MqE"), []byte("$1$"))) + secretProvider = s.singleUserProvider + } + authenticator = auth.NewBasicAuthenticator(s.Opt.Realm, secretProvider) } - authenticator := auth.NewBasicAuthenticator(s.Opt.Realm, secretProvider) oldHandler := handler handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // No auth wanted for OPTIONS method @@ -167,26 +203,36 @@ func NewServer(handler http.Handler, opt *Options) *Server { oldHandler.ServeHTTP(w, r) return } - if username := authenticator.CheckAuth(r); username == "" { - authHeader := r.Header.Get(authenticator.Headers.V().Authorization) - if authHeader != "" { - s := strings.SplitN(authHeader, " ", 2) - var userName = "UNKNOWN" - if len(s) == 2 && s[0] == "Basic" { - b, err := base64.StdEncoding.DecodeString(s[1]) - if err == nil { - userName = strings.SplitN(string(b), ":", 2)[0] - } - } - fs.Infof(r.URL.Path, "%s: Unauthorized request from %s", r.RemoteAddr, userName) - } else { - fs.Infof(r.URL.Path, "%s: Basic auth challenge sent", r.RemoteAddr) - } - authenticator.RequireAuth(w, r) - } else { - r = r.WithContext(context.WithValue(r.Context(), ContextUserKey, username)) - oldHandler.ServeHTTP(w, r) + unauthorized := func() { + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("WWW-Authenticate", `Basic realm="`+s.Opt.Realm+`"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) } + user, pass, authValid := parseAuthorization(r) + if !authValid { + unauthorized() + return + } + if s.Opt.Auth == nil { + if username := authenticator.CheckAuth(r); username == "" { + fs.Infof(r.URL.Path, "%s: Unauthorized request from %s", r.RemoteAddr, user) + unauthorized() + return + } + } else { + // Custom Auth + value, err := s.Opt.Auth(user, pass) + if err != nil { + fs.Infof(r.URL.Path, "%s: Auth failed from %s: %v", r.RemoteAddr, user, err) + unauthorized() + return + } + if value != nil { + r = r.WithContext(context.WithValue(r.Context(), ContextAuthKey, value)) + } + } + r = r.WithContext(context.WithValue(r.Context(), ContextUserKey, user)) + oldHandler.ServeHTTP(w, r) }) s.usingAuth = true } diff --git a/cmd/serve/webdav/webdav.go b/cmd/serve/webdav/webdav.go index 0f4609b22..7e818a0f0 100644 --- a/cmd/serve/webdav/webdav.go +++ b/cmd/serve/webdav/webdav.go @@ -12,9 +12,11 @@ import ( "github.com/rclone/rclone/cmd/serve/httplib" "github.com/rclone/rclone/cmd/serve/httplib/httpflags" "github.com/rclone/rclone/cmd/serve/httplib/serve" + "github.com/rclone/rclone/cmd/serve/proxy" + "github.com/rclone/rclone/cmd/serve/proxy/proxyflags" "github.com/rclone/rclone/fs" "github.com/rclone/rclone/fs/hash" - "github.com/rclone/rclone/fs/log" + "github.com/rclone/rclone/lib/errors" "github.com/rclone/rclone/vfs" "github.com/rclone/rclone/vfs/vfsflags" "github.com/spf13/cobra" @@ -30,6 +32,7 @@ var ( func init() { httpflags.AddFlags(Command.Flags()) vfsflags.AddFlags(Command.Flags()) + proxyflags.AddFlags(Command.Flags()) Command.Flags().StringVar(&hashName, "etag-hash", "", "Which hash to use for the ETag, or auto or blank for off") Command.Flags().BoolVar(&disableGETDir, "disable-dir-list", false, "Disable HTML directory list on GET request for a directory") } @@ -57,10 +60,15 @@ supported hash on the backend or you can use a named hash such as Use "rclone hashsum" to see the full list. -` + httplib.Help + vfs.Help, +` + httplib.Help + vfs.Help + proxy.Help, RunE: func(command *cobra.Command, args []string) error { - cmd.CheckArgs(1, 1, command, args) - f := cmd.NewFsSrc(args) + var f fs.Fs + if proxyflags.Opt.AuthProxy == "" { + cmd.CheckArgs(1, 1, command, args) + f = cmd.NewFsSrc(args) + } else { + cmd.CheckArgs(0, 0, command, args) + } hashType = hash.None if hashName == "auto" { hashType = f.Hashes().GetOne() @@ -101,8 +109,9 @@ Use "rclone hashsum" to see the full list. type WebDAV struct { *httplib.Server f fs.Fs - vfs *vfs.VFS + _vfs *vfs.VFS // don't use directly, use getVFS webdavhandler *webdav.Handler + proxy *proxy.Proxy } // check interface @@ -111,8 +120,16 @@ var _ webdav.FileSystem = (*WebDAV)(nil) // Make a new WebDAV to serve the remote func newWebDAV(f fs.Fs, opt *httplib.Options) *WebDAV { w := &WebDAV{ - f: f, - vfs: vfs.New(f, &vfsflags.Opt), + f: f, + } + if proxyflags.Opt.AuthProxy != "" { + w.proxy = proxy.New(&proxyflags.Opt) + // override auth + copyOpt := *opt + copyOpt.Auth = w.auth + opt = ©Opt + } else { + w._vfs = vfs.New(f, &vfsflags.Opt) } w.Server = httplib.NewServer(http.HandlerFunc(w.handler), opt) webdavHandler := &webdav.Handler{ @@ -125,6 +142,31 @@ func newWebDAV(f fs.Fs, opt *httplib.Options) *WebDAV { return w } +// Gets the VFS in use for this request +func (w *WebDAV) getVFS(ctx context.Context) (VFS *vfs.VFS, err error) { + if w._vfs != nil { + return w._vfs, nil + } + value := ctx.Value(httplib.ContextAuthKey) + if value == nil { + return nil, errors.New("no VFS found in context") + } + VFS, ok := value.(*vfs.VFS) + if !ok { + return nil, errors.Errorf("context value is not VFS: %#v", value) + } + return VFS, nil +} + +// auth does proxy authorization +func (w *WebDAV) auth(user, pass string) (value interface{}, err error) { + VFS, _, err := w.proxy.Call(user, pass) + if err != nil { + return nil, err + } + return VFS, err +} + func (w *WebDAV) handler(rw http.ResponseWriter, r *http.Request) { urlPath, ok := w.Path(rw, r) if !ok { @@ -142,8 +184,14 @@ func (w *WebDAV) handler(rw http.ResponseWriter, r *http.Request) { // serveDir serves a directory index at dirRemote // This is similar to serveDir in serve http. func (w *WebDAV) serveDir(rw http.ResponseWriter, r *http.Request, dirRemote string) { + VFS, err := w.getVFS(r.Context()) + if err != nil { + http.Error(rw, "Root directory not found", http.StatusNotFound) + fs.Errorf(nil, "Failed to serve directory: %v", err) + return + } // List the directory - node, err := w.vfs.Stat(dirRemote) + node, err := VFS.Stat(dirRemote) if err == vfs.ENOENT { http.Error(rw, "Directory not found", http.StatusNotFound) return @@ -190,8 +238,12 @@ func (w *WebDAV) logRequest(r *http.Request, err error) { // Mkdir creates a directory func (w *WebDAV) Mkdir(ctx context.Context, name string, perm os.FileMode) (err error) { - defer log.Trace(name, "perm=%v", perm)("err = %v", &err) - dir, leaf, err := w.vfs.StatParent(name) + // defer log.Trace(name, "perm=%v", perm)("err = %v", &err) + VFS, err := w.getVFS(ctx) + if err != nil { + return err + } + dir, leaf, err := VFS.StatParent(name) if err != nil { return err } @@ -201,8 +253,12 @@ func (w *WebDAV) Mkdir(ctx context.Context, name string, perm os.FileMode) (err // OpenFile opens a file or a directory func (w *WebDAV) OpenFile(ctx context.Context, name string, flags int, perm os.FileMode) (file webdav.File, err error) { - defer log.Trace(name, "flags=%v, perm=%v", flags, perm)("err = %v", &err) - f, err := w.vfs.OpenFile(name, flags, perm) + // defer log.Trace(name, "flags=%v, perm=%v", flags, perm)("err = %v", &err) + VFS, err := w.getVFS(ctx) + if err != nil { + return nil, err + } + f, err := VFS.OpenFile(name, flags, perm) if err != nil { return nil, err } @@ -211,8 +267,12 @@ func (w *WebDAV) OpenFile(ctx context.Context, name string, flags int, perm os.F // RemoveAll removes a file or a directory and its contents func (w *WebDAV) RemoveAll(ctx context.Context, name string) (err error) { - defer log.Trace(name, "")("err = %v", &err) - node, err := w.vfs.Stat(name) + // defer log.Trace(name, "")("err = %v", &err) + VFS, err := w.getVFS(ctx) + if err != nil { + return err + } + node, err := VFS.Stat(name) if err != nil { return err } @@ -225,14 +285,22 @@ func (w *WebDAV) RemoveAll(ctx context.Context, name string) (err error) { // Rename a file or a directory func (w *WebDAV) Rename(ctx context.Context, oldName, newName string) (err error) { - defer log.Trace(oldName, "newName=%q", newName)("err = %v", &err) - return w.vfs.Rename(oldName, newName) + // defer log.Trace(oldName, "newName=%q", newName)("err = %v", &err) + VFS, err := w.getVFS(ctx) + if err != nil { + return err + } + return VFS.Rename(oldName, newName) } // Stat returns info about the file or directory func (w *WebDAV) Stat(ctx context.Context, name string) (fi os.FileInfo, err error) { - defer log.Trace(name, "")("fi=%+v, err = %v", &fi, &err) - fi, err = w.vfs.Stat(name) + // defer log.Trace(name, "")("fi=%+v, err = %v", &fi, &err) + VFS, err := w.getVFS(ctx) + if err != nil { + return nil, err + } + fi, err = VFS.Stat(name) if err != nil { return nil, err } @@ -274,7 +342,7 @@ type FileInfo struct { // ETag returns an ETag for the FileInfo func (fi FileInfo) ETag(ctx context.Context) (etag string, err error) { - defer log.Trace(fi, "")("etag=%q, err=%v", &etag, &err) + // defer log.Trace(fi, "")("etag=%q, err=%v", &etag, &err) if hashType == hash.None { return "", webdav.ErrNotImplemented } @@ -297,7 +365,7 @@ func (fi FileInfo) ETag(ctx context.Context) (etag string, err error) { // ContentType returns a content type for the FileInfo func (fi FileInfo) ContentType(ctx context.Context) (contentType string, err error) { - defer log.Trace(fi, "")("etag=%q, err=%v", &contentType, &err) + // defer log.Trace(fi, "")("etag=%q, err=%v", &contentType, &err) node, ok := (fi.FileInfo).(vfs.Node) if !ok { fs.Errorf(fi, "Expecting vfs.Node, got %T", fi.FileInfo)