diff --git a/aead.go b/aead.go new file mode 100644 index 0000000..94cf30e --- /dev/null +++ b/aead.go @@ -0,0 +1,44 @@ +package fumitok + +import ( + "crypto/cipher" + "crypto/rand" + "encoding/binary" + "errors" +) + +var ( + ErrCipherTextTooShort = errors.New("ciphertext too short") +) + +func encode(aead cipher.AEAD, additional uint16, b []byte) []byte { + nsz := aead.NonceSize() + // Accocate capacity for all the stuffs. + buf := make([]byte, 2+nsz+len(b)+aead.Overhead()) + binary.LittleEndian.PutUint16(buf[:2], additional) + nonce := buf[2 : 2+nsz] + // Select a random nonce + _, err := rand.Read(nonce) + if err != nil { + panic(err) + } + // Encrypt the message and append the ciphertext to the nonce. + eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf[:2]) + return nonce[:nsz+len(eb)] +} + +func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) { + nsz := aead.NonceSize() + if len(b) < nsz { + return nil, ErrCipherTextTooShort + } + // Split nonce and ciphertext. + nonce, ciphertext := b[:nsz], b[nsz:] + if len(ciphertext) == 0 { + return nil, nil + } + // Decrypt the message and check it wasn't tampered with. + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], additional) + return aead.Open(nil, nonce, ciphertext, buf[:]) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..091239a --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module github.com/fumiama/fumitok + +go 1.22.1 + +require golang.org/x/crypto v0.28.0 + +require golang.org/x/sys v0.26.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..25fba75 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/helper.go b/helper.go new file mode 100644 index 0000000..1c895c0 --- /dev/null +++ b/helper.go @@ -0,0 +1,15 @@ +package fumitok + +import ( + "unsafe" +) + +// BytesToString 没有内存开销的转换 +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToBytes 没有内存开销的转换 +func StringToBytes(s string) (b []byte) { + return unsafe.Slice(unsafe.StringData(s), len(s)) +} diff --git a/token.go b/token.go new file mode 100644 index 0000000..8cd26a9 --- /dev/null +++ b/token.go @@ -0,0 +1,183 @@ +package fumitok + +import ( + "bytes" + "crypto/cipher" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "errors" + "hash/crc64" + "math/rand/v2" + "time" + + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + TokenLength = 88 +) + +var ( + ErrInvalidTokenKeySize = errors.New("invalid token key size") + ErrExpiredToken = errors.New("expired token") + ErrInvalidToken = errors.New("invalid token") + ErrInvalidTokenLength = errors.New("invalid token len") +) + +// Tokenizer xchacha20 对称加密密钥 +type Tokenizer struct { + aead cipher.AEAD + tabl *crc64.Table +} + +// NewTokenizer ... +func NewTokenizer(hexkeystr string) (t Tokenizer, err error) { + key, err := hex.DecodeString(hexkeystr) + if err != nil { + return + } + if len(key) != chacha20poly1305.KeySize { + err = ErrInvalidTokenKeySize + return + } + t.aead, err = chacha20poly1305.NewX(key) + if err != nil { + return + } + t.tabl = crc64.MakeTable(crc64.ECMA) + return +} + +// Generate 生成 token +// +// - id 用户标识符, 被加密 +// - expireAt 有效期至 +// - addt, mask 附加数据和其掩码, 不被加密 +func (t *Tokenizer) Generate(id uint64, expireAt time.Time, addt, mask uint16) (string, error) { + var buf [2 + 8 + 8 + 8]byte + text := buf[2:] + binary.LittleEndian.PutUint64(text[:8], uint64(expireAt.UnixMilli())) + binary.LittleEndian.PutUint64(text[8:16], id) + h := crc64.New(t.tabl) + _, err := h.Write(text[:16]) + if err != nil { + return "", err + } + _ = h.Sum(text[16:16]) + addt &= mask + addt |= (uint16(rand.Uint32()) & (^mask)) + binary.LittleEndian.PutUint16(buf[:2], addt) + w := bytes.NewBuffer(make([]byte, 0, 64)) + enc := base64.NewEncoder(base64.URLEncoding, w) + _, err = enc.Write(buf[:2]) + if err != nil { + return "", err + } + _, err = enc.Write(encode(t.aead, addt, text)) + if err != nil { + return "", err + } + err = enc.Close() + if err != nil { + return "", err + } + return BytesToString(w.Bytes()), nil +} + +// Validate 验证并提取信息 +// +// # 参数 +// - token 待验证凭据 +// - mask 附加数据之掩码, 将在返回时做掩模 +// - check 在解码前检查附加数据是否符合要求 +// +// # 返回 +// - uint64 用户标识符 +// - uint16 附加数据 +func (t *Tokenizer) Validate( + token string, mask uint16, checks ...func(uint16) error, +) (uint64, uint16, error) { + if len(token) != TokenLength { + return 0, 0, ErrInvalidTokenLength + } + data, err := base64.URLEncoding.DecodeString(token) + if err != nil { + return 0, 0, err + } + addt := binary.LittleEndian.Uint16(data[:2]) + addtmsk := addt & mask + for _, fn := range checks { + err = fn(addtmsk) + if err != nil { + return 0, 0, err + } + } + data, err = decode(t.aead, addt, data[2:]) + if err != nil { + return 0, 0, err + } + h := crc64.New(t.tabl) + _, err = h.Write(data[:16]) + if err != nil { + return 0, 0, err + } + crc := binary.BigEndian.Uint64(data[16:]) + if crc != h.Sum64() { + return 0, 0, ErrInvalidToken + } + if time.Now().After(time.UnixMilli(int64(binary.LittleEndian.Uint64(data[:8])))) { + return 0, 0, ErrExpiredToken + } + return binary.LittleEndian.Uint64(data[8:16]), addtmsk, nil +} + +// Refresh 过期时刷新 token +// +// - token 旧 token +// - expireAt 新的过期时间 +// - validAfter 旧 token 过期此时间段内仍可用于刷新 +// - mask 附加数据之掩码 +// - check 在解码前检查附加数据是否符合要求 +func (t *Tokenizer) Refresh( + token string, expireAt time.Time, validAfter time.Duration, + mask uint16, checks ...func(uint16) error, +) (string, error) { + if len(token) != TokenLength { + return "", ErrInvalidTokenLength + } + data, err := base64.URLEncoding.DecodeString(token) + if err != nil { + return "", err + } + addt := binary.LittleEndian.Uint16(data[:2]) + addtmsk := addt & mask + for _, fn := range checks { + err = fn(addtmsk) + if err != nil { + return "", err + } + } + data, err = decode(t.aead, addt, data[2:]) + if err != nil { + return "", err + } + h := crc64.New(t.tabl) + _, err = h.Write(data[:16]) + if err != nil { + return "", err + } + crc := binary.BigEndian.Uint64(data[16:]) + if crc != h.Sum64() { + return "", ErrInvalidToken + } + if time.Now().Add(validAfter).After( // still valid after 30 mins + time.UnixMilli(int64(binary.LittleEndian.Uint64(data[:8]))), + ) { + return "", ErrExpiredToken + } + return t.Generate( + binary.LittleEndian.Uint64(data[8:16]), + expireAt, addtmsk, mask, + ) +} diff --git a/token_test.go b/token_test.go new file mode 100644 index 0000000..bb292ac --- /dev/null +++ b/token_test.go @@ -0,0 +1,42 @@ +package fumitok + +import ( + "crypto/rand" + "encoding/hex" + "testing" + "time" +) + +func TestTokenizer(t *testing.T) { + var key [32]byte + _, err := rand.Read(key[:]) + if err != nil { + t.Fatal(err) + } + tk, err := NewTokenizer(hex.EncodeToString(key[:])) + if err != nil { + t.Fatal(err) + } + id := uint64(3719371987) + token, err := tk.Generate(id, time.Now().Add(time.Minute), 0x1234, 0x00ff) + if err != nil { + t.Fatal(err) + } + t.Log(token) + vid, addt, err := tk.Validate(token, 0x00ff) + if err != nil { + t.Fatal(err) + } + if vid != id || addt != 0x34 { + t.Fatal("id", id, "vid", vid, "addt", addt) + } + token, err = tk.Generate(id, time.Now().Add(-time.Minute), 0, 0) + if err != nil { + t.Fatal(err) + } + t.Log(token) + _, _, err = tk.Validate(token, 0) + if err != ErrExpiredToken { + t.Fatal("unexpected err", err) + } +}