diff --git a/cmd/main.go b/cmd/main.go index e7287a5..97bd7b6 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -3,6 +3,9 @@ package main import ( "bytes" "crypto/md5" + "crypto/rand" + "encoding/hex" + "errors" "flag" "io" "net/http" @@ -23,10 +26,32 @@ type imagebody struct { dat []byte } +var ( + errInvalidTokenLength = errors.New("invalid token length") + errInvalidToken = errors.New("invalid token") +) + func main() { cachetime := flag.Uint("t", 60, "cache time (s)") endpoint := flag.String("e", "127.0.0.1:8000", "listening endpoint") + var tok [32]byte + _, err := rand.Read(tok[:]) + if err != nil { + panic(err) + } + token := flag.String("k", hex.EncodeToString(tok[:]), "put/delete token") flag.Parse() + if len(*token) != 64 { + panic(errInvalidTokenLength) + } + n, err := hex.Decode(tok[:], imoto.StringToBytes(*token)) + if err != nil { + panic(err) + } + if n != 32 { + panic(errInvalidToken) + } + logrus.Infoln("listening to", *endpoint, "with token", hex.EncodeToString(tok[:])) imgcache = ttl.NewCache[uint64, *imagebody](time.Second * time.Duration(*cachetime)) http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { m, err := imoto.GetMD5(r.URL.Path) @@ -54,6 +79,15 @@ func main() { w.Header().Set("Content-Type", "image/"+img.typ) _, _ = w.Write(img.dat) case http.MethodPut: + err := checktoken(&tok, r) + if err != nil { + http.Error(w, "400 Bad Request: "+err.Error(), http.StatusBadRequest) + return + } + if imgcache.Get(p) != nil { + w.WriteHeader(http.StatusOK) + return + } data, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "500 Internal Server Error: "+err.Error(), http.StatusInternalServerError) @@ -76,6 +110,11 @@ func main() { }) w.WriteHeader(http.StatusOK) case http.MethodDelete: + err := checktoken(&tok, r) + if err != nil { + http.Error(w, "400 Bad Request: "+err.Error(), http.StatusBadRequest) + return + } img := imgcache.Get(p) if img == nil { w.WriteHeader(http.StatusNotFound) @@ -95,3 +134,18 @@ func main() { }) logrus.Errorln(http.ListenAndServe(*endpoint, nil)) } + +func checktoken(tok *[32]byte, r *http.Request) error { + t := r.Header.Get("Authorization") + if len(t) != 64 { + return errInvalidTokenLength + } + usrtok, err := hex.DecodeString(t) + if err != nil { + return err + } + if !bytes.Equal(usrtok, tok[:]) { + return errInvalidToken + } + return nil +} diff --git a/go.mod b/go.mod index 629d58b..7431751 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.20 require ( github.com/FloatTech/ttl v0.0.0-20230307105452-d6f7b2b647d1 github.com/fumiama/imgsz v0.0.2 + github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.3 ) diff --git a/go.sum b/go.sum index c8a639b..996d62d 100644 --- a/go.sum +++ b/go.sum @@ -5,6 +5,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fumiama/imgsz v0.0.2 h1:fAkC0FnIscdKOXwAxlyw3EUba5NzxZdSxGaq3Uyfxak= github.com/fumiama/imgsz v0.0.2/go.mod h1:dR71mI3I2O5u6+PCpd47M9TZptzP+39tRBcbdIkoqM4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/helper.go b/helper.go index ff5cf78..1d31869 100644 --- a/helper.go +++ b/helper.go @@ -5,15 +5,39 @@ import ( "encoding/binary" "encoding/hex" "errors" + "runtime" "strconv" "strings" "unsafe" ) +func getFuncAndFileNameWithSkip(n int) (string, string) { + pc, fn, _, ok := runtime.Caller(n) + if !ok { + return "", "" + } + i := strings.LastIndex(fn, "/") + 1 + if i > 0 { + fn = strings.TrimSuffix(fn[i:], ".go") + } + fullname := runtime.FuncForPC(pc).Name() + i = strings.LastIndex(fullname, ".") + 1 + if i <= 0 || i >= len(fullname) { + return fullname, fn + } + return fullname[i:], fn +} + +// getThisFuncName 获取正在执行的函数名 +func getThisFuncName() string { + x, _ := getFuncAndFileNameWithSkip(2) + return x +} + func GetMD5(u string) (m [md5.Size]byte, err error) { u = strings.Trim(u, "/ ?&\n\t") if len(u) != md5.Size*2 && len(u) != md5.Size { - err = errors.New("invalid path len: " + strconv.Itoa(len(u))) + err = errors.New(getThisFuncName() + ": invalid path len: " + strconv.Itoa(len(u))) return } _, err = hex.Decode(m[:], StringToBytes(u)) diff --git a/image.go b/image.go new file mode 100644 index 0000000..c72da7d --- /dev/null +++ b/image.go @@ -0,0 +1,111 @@ +package imoto + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "encoding/hex" + "io" + "net/http" + "strings" + + "github.com/pkg/errors" +) + +var ( + API = "https://imoto.seku.su" +) + +var ( + ErrInvalidURL = errors.New("invalid URL") +) + +// Live judge if the image is exist +func Live(u string) bool { + req, err := http.NewRequest("HEAD", u, nil) + if err != nil { + return false + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false + } + return resp.StatusCode == http.StatusOK +} + +// Bed image to server +func Bed(t string, b []byte) (string, uint64, uint64, error) { + m := md5.Sum(b) + u := API + hex.EncodeToString(m[:]) + req, err := http.NewRequest("PUT", u, bytes.NewReader(b)) + if err != nil { + return "", 0, 0, errors.Wrap(err, getThisFuncName()) + } + req.Header.Add("Authorization", t) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", 0, 0, errors.Wrap(err, getThisFuncName()) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + msg, _ := io.ReadAll(resp.Body) + return "", 0, 0, errors.New(getThisFuncName() + ": " + BytesToString(msg)) + } + p, k := SplitMD5(m) + return u[:len(u)-16], p, k, nil +} + +// Use a URL once and delete it immediately +func Use(t string, u string, k uint64) ([]byte, error) { + i := strings.LastIndex(u, "/") + if i < 0 { + return nil, errors.Wrap(ErrInvalidURL, getThisFuncName()) + } + ms := u[i+1:] + var m [md5.Size]byte + switch len(ms) { + case 32: + n, err := hex.Decode(m[:], StringToBytes(ms)) + if err != nil { + return nil, errors.Wrap(err, getThisFuncName()) + } + if n != md5.Size { + return nil, errors.Wrap(ErrInvalidURL, getThisFuncName()) + } + case 16: + n, err := hex.Decode(m[:8], StringToBytes(ms)) + if err != nil { + return nil, errors.Wrap(err, getThisFuncName()) + } + if n != 8 { + return nil, errors.Wrap(ErrInvalidURL, getThisFuncName()) + } + binary.LittleEndian.PutUint64(m[8:], k) + u += Uint64String(k) + default: + return nil, errors.Wrap(ErrInvalidURL, getThisFuncName()) + } + req, err := http.NewRequest("DELETE", u, nil) + if err != nil { + return nil, errors.Wrap(err, getThisFuncName()) + } + req.Header.Add("Authorization", t) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, errors.Wrap(err, getThisFuncName()) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + msg, _ := io.ReadAll(resp.Body) + return nil, errors.New(getThisFuncName() + ": " + BytesToString(msg)) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, getThisFuncName()) + } + m2 := md5.Sum(data) + if m2 != m { + return nil, errors.New(getThisFuncName() + ": expect " + hex.EncodeToString(m[:]) + " but got " + hex.EncodeToString(m2[:])) + } + return data, nil +} diff --git a/test/main.go b/test/main.go index e0d4f2f..ea6c87a 100644 --- a/test/main.go +++ b/test/main.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/md5" "encoding/hex" "io" @@ -17,8 +16,17 @@ func main() { panic(err) } m := md5.Sum(data) - p, _ := imoto.SplitMD5(m) - req, err := http.NewRequest("PUT", "http://127.0.0.1:8000/"+hex.EncodeToString(m[:]), bytes.NewReader(data)) + imoto.API = "http://127.0.0.1:8000/" + token := "0000000000000000000000000000000000000000000000000000000000000000" + u, _, k, err := imoto.Bed(token, data) + if err != nil { + panic(err) + } + isexist := imoto.Live(u) + if !isexist { + panic("HEAD") + } + req, err := http.NewRequest("GET", u, nil) if err != nil { panic(err) } @@ -29,32 +37,7 @@ func main() { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { msg, _ := io.ReadAll(resp.Body) - panic("PUT error: " + imoto.BytesToString(msg)) - } - req, err = http.NewRequest("HEAD", "http://127.0.0.1:8000/"+imoto.Uint64String(p), nil) - if err != nil { - panic(err) - } - resp, err = http.DefaultClient.Do(req) - if err != nil { - panic(err) - } - if resp.StatusCode != http.StatusOK { - msg, _ := io.ReadAll(resp.Body) - panic("HEAD error: " + imoto.BytesToString(msg)) - } - req, err = http.NewRequest("GET", "http://127.0.0.1:8000/"+imoto.Uint64String(p), nil) - if err != nil { - panic(err) - } - resp, err = http.DefaultClient.Do(req) - if err != nil { - panic(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - msg, _ := io.ReadAll(resp.Body) - panic("HEAD error: " + imoto.BytesToString(msg)) + panic("GET error: " + imoto.BytesToString(msg)) } h := md5.New() _, err = io.Copy(h, resp.Body) @@ -66,26 +49,8 @@ func main() { if m2 != m { panic("GET error: expected " + hex.EncodeToString(m[:]) + " but got " + hex.EncodeToString(m2[:])) } - req, err = http.NewRequest("DELETE", "http://127.0.0.1:8000/"+hex.EncodeToString(m[:]), nil) + _, err = imoto.Use(token, u, k) if err != nil { panic(err) } - resp, err = http.DefaultClient.Do(req) - if err != nil { - panic(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - msg, _ := io.ReadAll(resp.Body) - panic("HEAD error: " + imoto.BytesToString(msg)) - } - h = md5.New() - _, err = io.Copy(h, resp.Body) - if err != nil { - panic(err) - } - h.Sum(m2[:0]) - if m2 != m { - panic("DELETE error: expected " + hex.EncodeToString(m[:]) + " but got " + hex.EncodeToString(m2[:])) - } }