diff --git a/access.go b/access.go new file mode 100644 index 0000000..bacea24 --- /dev/null +++ b/access.go @@ -0,0 +1,138 @@ +package main + +import ( + "encoding/json" + "errors" + "net/http" + "regexp" + "strings" + "time" + + "github.com/andreimarcu/linx-server/backends" + "github.com/flosch/pongo2" + "github.com/zenazn/goji/web" +) + +type accessKeySource int + +const ( + accessKeySourceNone accessKeySource = iota + accessKeySourceCookie + accessKeySourceHeader + accessKeySourceForm + accessKeySourceQuery +) + +const accessKeyHeaderName = "Linx-Access-Key" +const accessKeyParamName = "access_key" + +var ( + errInvalidAccessKey = errors.New("invalid access key") + + cliUserAgentRe = regexp.MustCompile("(?i)(lib)?curl|wget") +) + +func checkAccessKey(r *http.Request, metadata *backends.Metadata) (accessKeySource, error) { + key := metadata.AccessKey + if key == "" { + return accessKeySourceNone, nil + } + + cookieKey, err := r.Cookie(accessKeyHeaderName) + if err == nil { + if cookieKey.Value == key { + return accessKeySourceCookie, nil + } + return accessKeySourceCookie, errInvalidAccessKey + } + + headerKey := r.Header.Get(accessKeyHeaderName) + if headerKey == key { + return accessKeySourceHeader, nil + } else if headerKey != "" { + return accessKeySourceHeader, errInvalidAccessKey + } + + formKey := r.PostFormValue(accessKeyParamName) + if formKey == key { + return accessKeySourceForm, nil + } else if formKey != "" { + return accessKeySourceForm, errInvalidAccessKey + } + + queryKey := r.URL.Query().Get(accessKeyParamName) + if queryKey == key { + return accessKeySourceQuery, nil + } else if formKey != "" { + return accessKeySourceQuery, errInvalidAccessKey + } + + return accessKeySourceNone, errInvalidAccessKey +} + +func setAccessKeyCookies(w http.ResponseWriter, domain, fileName, value string, expires time.Time) { + cookie := http.Cookie{ + Name: accessKeyHeaderName, + Value: value, + HttpOnly: true, + Domain: domain, + Expires: expires, + } + + cookie.Path = Config.sitePath + fileName + http.SetCookie(w, &cookie) + + cookie.Path = Config.sitePath + Config.selifPath + fileName + http.SetCookie(w, &cookie) +} + +func fileAccessHandler(c web.C, w http.ResponseWriter, r *http.Request) { + if !Config.noDirectAgents && cliUserAgentRe.MatchString(r.Header.Get("User-Agent")) && !strings.EqualFold("application/json", r.Header.Get("Accept")) { + fileServeHandler(c, w, r) + return + } + + fileName := c.URLParams["name"] + + metadata, err := checkFile(fileName) + if err == backends.NotFoundErr { + notFoundHandler(c, w, r) + return + } else if err != nil { + oopsHandler(c, w, r, RespAUTO, "Corrupt metadata.") + return + } + + if src, err := checkAccessKey(r, &metadata); err != nil { + // remove invalid cookie + if src == accessKeySourceCookie { + setAccessKeyCookies(w, getSiteURL(r), fileName, "", time.Unix(0, 0)) + } + + if strings.EqualFold("application/json", r.Header.Get("Accept")) { + dec := json.NewEncoder(w) + _ = dec.Encode(map[string]string{ + "error": errInvalidAccessKey.Error(), + }) + + return + } + + _ = renderTemplate(Templates["access.html"], pongo2.Context{ + "filename": fileName, + "accesspath": fileName, + }, r, w) + + return + } + + if metadata.AccessKey != "" { + var expiry time.Time + if Config.accessKeyCookieExpiry != 0 { + expiry = time.Now().Add(time.Duration(Config.accessKeyCookieExpiry) * time.Second) + } + setAccessKeyCookies(w, getSiteURL(r), fileName, metadata.AccessKey, expiry) + } + + fileDisplayHandler(c, w, r, fileName, metadata) +} diff --git a/backends/localfs/localfs.go b/backends/localfs/localfs.go index 47187b6..42e32b8 100644 --- a/backends/localfs/localfs.go +++ b/backends/localfs/localfs.go @@ -19,6 +19,7 @@ type LocalfsBackend struct { type MetadataJSON struct { DeleteKey string `json:"delete_key"` + AccessKey string `json:"access_key,omitempty"` Sha256sum string `json:"sha256sum"` Mimetype string `json:"mimetype"` Size int64 `json:"size"` @@ -57,6 +58,7 @@ func (b LocalfsBackend) Head(key string) (metadata backends.Metadata, err error) } metadata.DeleteKey = mjson.DeleteKey + metadata.AccessKey = mjson.AccessKey metadata.Mimetype = mjson.Mimetype metadata.ArchiveFiles = mjson.ArchiveFiles metadata.Sha256sum = mjson.Sha256sum @@ -84,12 +86,13 @@ func (b LocalfsBackend) writeMetadata(key string, metadata backends.Metadata) er metaPath := path.Join(b.metaPath, key) mjson := MetadataJSON{ - DeleteKey: metadata.DeleteKey, - Mimetype: metadata.Mimetype, + DeleteKey: metadata.DeleteKey, + AccessKey: metadata.AccessKey, + Mimetype: metadata.Mimetype, ArchiveFiles: metadata.ArchiveFiles, - Sha256sum: metadata.Sha256sum, - Expiry: metadata.Expiry.Unix(), - Size: metadata.Size, + Sha256sum: metadata.Sha256sum, + Expiry: metadata.Expiry.Unix(), + Size: metadata.Size, } dst, err := os.Create(metaPath) @@ -108,7 +111,7 @@ func (b LocalfsBackend) writeMetadata(key string, metadata backends.Metadata) er return nil } -func (b LocalfsBackend) Put(key string, r io.Reader, expiry time.Time, deleteKey string) (m backends.Metadata, err error) { +func (b LocalfsBackend) Put(key string, r io.Reader, expiry time.Time, deleteKey, accessKey string) (m backends.Metadata, err error) { filePath := path.Join(b.filesPath, key) dst, err := os.Create(filePath) @@ -126,16 +129,17 @@ func (b LocalfsBackend) Put(key string, r io.Reader, expiry time.Time, deleteKey return m, err } - dst.Seek(0 ,0) + dst.Seek(0, 0) m, err = helpers.GenerateMetadata(dst) if err != nil { os.Remove(filePath) return } - dst.Seek(0 ,0) + dst.Seek(0, 0) m.Expiry = expiry m.DeleteKey = deleteKey + m.AccessKey = accessKey m.ArchiveFiles, _ = helpers.ListArchiveFiles(m.Mimetype, m.Size, dst) err = b.writeMetadata(key, m) diff --git a/backends/meta.go b/backends/meta.go index 7ba522d..b22276e 100644 --- a/backends/meta.go +++ b/backends/meta.go @@ -7,6 +7,7 @@ import ( type Metadata struct { DeleteKey string + AccessKey string Sha256sum string Mimetype string Size int64 diff --git a/backends/s3/s3.go b/backends/s3/s3.go index fc2a1b0..afdabf0 100644 --- a/backends/s3/s3.go +++ b/backends/s3/s3.go @@ -86,6 +86,7 @@ func mapMetadata(m backends.Metadata) map[string]*string { "Size": aws.String(strconv.FormatInt(m.Size, 10)), "Mimetype": aws.String(m.Mimetype), "Sha256sum": aws.String(m.Sha256sum), + "AccessKey": aws.String(m.AccessKey), } } @@ -104,10 +105,15 @@ func unmapMetadata(input map[string]*string) (m backends.Metadata, err error) { m.DeleteKey = aws.StringValue(input["Delete_key"]) m.Mimetype = aws.StringValue(input["Mimetype"]) m.Sha256sum = aws.StringValue(input["Sha256sum"]) + + if key, ok := input["AccessKey"]; ok { + m.AccessKey = aws.StringValue(key) + } + return } -func (b S3Backend) Put(key string, r io.Reader, expiry time.Time, deleteKey string) (m backends.Metadata, err error) { +func (b S3Backend) Put(key string, r io.Reader, expiry time.Time, deleteKey, accessKey string) (m backends.Metadata, err error) { tmpDst, err := ioutil.TempFile("", "linx-server-upload") if err != nil { return m, err @@ -133,6 +139,7 @@ func (b S3Backend) Put(key string, r io.Reader, expiry time.Time, deleteKey stri } m.Expiry = expiry m.DeleteKey = deleteKey + m.AccessKey = accessKey // XXX: we may not be able to write this to AWS easily //m.ArchiveFiles, _ = helpers.ListArchiveFiles(m.Mimetype, m.Size, tmpDst) diff --git a/backends/storage.go b/backends/storage.go index fdd8cd6..5d973c4 100644 --- a/backends/storage.go +++ b/backends/storage.go @@ -11,7 +11,7 @@ type StorageBackend interface { Exists(key string) (bool, error) Head(key string) (Metadata, error) Get(key string) (Metadata, io.ReadCloser, error) - Put(key string, r io.Reader, expiry time.Time, deleteKey string) (Metadata, error) + Put(key string, r io.Reader, expiry time.Time, deleteKey, accessKey string) (Metadata, error) PutMetadata(key string, m Metadata) error Size(key string) (int64, error) } diff --git a/display.go b/display.go index feb16da..e15b1b6 100644 --- a/display.go +++ b/display.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "net/http" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -21,24 +20,7 @@ import ( const maxDisplayFileSizeBytes = 1024 * 512 -var cliUserAgentRe = regexp.MustCompile("(?i)(lib)?curl|wget") - -func fileDisplayHandler(c web.C, w http.ResponseWriter, r *http.Request) { - if !Config.noDirectAgents && cliUserAgentRe.MatchString(r.Header.Get("User-Agent")) && !strings.EqualFold("application/json", r.Header.Get("Accept")) { - fileServeHandler(c, w, r) - return - } - - fileName := c.URLParams["name"] - - metadata, err := checkFile(fileName) - if err == backends.NotFoundErr { - notFoundHandler(c, w, r) - return - } else if err != nil { - oopsHandler(c, w, r, RespAUTO, "Corrupt metadata.") - return - } +func fileDisplayHandler(c web.C, w http.ResponseWriter, r *http.Request, fileName string, metadata backends.Metadata) { var expiryHuman string if metadata.Expiry != expiry.NeverExpire { expiryHuman = humanize.RelTime(time.Now(), metadata.Expiry, "", "") @@ -130,7 +112,7 @@ func fileDisplayHandler(c web.C, w http.ResponseWriter, r *http.Request) { tpl = Templates["display/file.html"] } - err = renderTemplate(tpl, pongo2.Context{ + err := renderTemplate(tpl, pongo2.Context{ "mime": metadata.Mimetype, "filename": fileName, "size": sizeHuman, diff --git a/fileserve.go b/fileserve.go index 202e477..27a28a9 100644 --- a/fileserve.go +++ b/fileserve.go @@ -27,6 +27,16 @@ func fileServeHandler(c web.C, w http.ResponseWriter, r *http.Request) { return } + if src, err := checkAccessKey(r, &metadata); err != nil { + // remove invalid cookie + if src == accessKeySourceCookie { + setAccessKeyCookies(w, getSiteURL(r), fileName, "", time.Unix(0, 0)) + } + unauthorizedHandler(c, w, r) + + return + } + if !Config.allowHotlink { referer := r.Header.Get("Referer") u, _ := url.Parse(referer) diff --git a/server.go b/server.go index 71a9c4d..50c4465 100644 --- a/server.go +++ b/server.go @@ -15,7 +15,7 @@ import ( "syscall" "time" - "github.com/GeertJohan/go.rice" + rice "github.com/GeertJohan/go.rice" "github.com/andreimarcu/linx-server/backends" "github.com/andreimarcu/linx-server/backends/localfs" "github.com/andreimarcu/linx-server/backends/s3" @@ -68,6 +68,7 @@ var Config struct { s3Bucket string s3ForcePathStyle bool forceRandomFilename bool + accessKeyCookieExpiry uint64 } var Templates = make(map[string]*pongo2.Template) @@ -200,7 +201,8 @@ func setup() *web.Mux { mux.Get(Config.sitePath+"static/*", staticHandler) mux.Get(Config.sitePath+"favicon.ico", staticHandler) mux.Get(Config.sitePath+"robots.txt", staticHandler) - mux.Get(nameRe, fileDisplayHandler) + mux.Get(nameRe, fileAccessHandler) + mux.Post(nameRe, fileAccessHandler) mux.Get(selifRe, fileServeHandler) mux.Get(selifIndexRe, unauthorizedHandler) mux.Get(torrentRe, fileTorrentHandler) @@ -273,6 +275,7 @@ func main() { "Force path-style addressing for S3 (e.g. https://s3.amazonaws.com/linx/example.txt)") flag.BoolVar(&Config.forceRandomFilename, "force-random-filename", false, "Force all uploads to use a random filename") + flag.Uint64Var(&Config.accessKeyCookieExpiry, "access-cookie-expiry", 0, "Expiration time for access key cookies in seconds (set 0 to use session cookies)") iniflags.Parse() diff --git a/templates.go b/templates.go index 79c90ce..acd9980 100644 --- a/templates.go +++ b/templates.go @@ -8,7 +8,7 @@ import ( "path/filepath" "strings" - "github.com/GeertJohan/go.rice" + rice "github.com/GeertJohan/go.rice" "github.com/flosch/pongo2" ) @@ -51,6 +51,7 @@ func populateTemplatesMap(tSet *pongo2.TemplateSet, tMap map[string]*pongo2.Temp "401.html", "404.html", "oops.html", + "access.html", "display/audio.html", "display/image.html", diff --git a/templates/access.html b/templates/access.html new file mode 100644 index 0000000..be3ada9 --- /dev/null +++ b/templates/access.html @@ -0,0 +1,11 @@ +{% extends "base.html" %} + +{% block content %} +