This commit is contained in:
源文雨
2024-10-13 19:43:12 +09:00
parent efbd9f6b20
commit 62c237674f
6 changed files with 295 additions and 0 deletions

44
aead.go Normal file
View File

@@ -0,0 +1,44 @@
package fumitok
import (
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
)
var (
ErrCipherTextTooShort = errors.New("ciphertext too short")
)
func encode(aead cipher.AEAD, additional uint16, b []byte) []byte {
nsz := aead.NonceSize()
// Accocate capacity for all the stuffs.
buf := make([]byte, 2+nsz+len(b)+aead.Overhead())
binary.LittleEndian.PutUint16(buf[:2], additional)
nonce := buf[2 : 2+nsz]
// Select a random nonce
_, err := rand.Read(nonce)
if err != nil {
panic(err)
}
// Encrypt the message and append the ciphertext to the nonce.
eb := aead.Seal(nonce[nsz:nsz], nonce, b, buf[:2])
return nonce[:nsz+len(eb)]
}
func decode(aead cipher.AEAD, additional uint16, b []byte) ([]byte, error) {
nsz := aead.NonceSize()
if len(b) < nsz {
return nil, ErrCipherTextTooShort
}
// Split nonce and ciphertext.
nonce, ciphertext := b[:nsz], b[nsz:]
if len(ciphertext) == 0 {
return nil, nil
}
// Decrypt the message and check it wasn't tampered with.
var buf [2]byte
binary.LittleEndian.PutUint16(buf[:], additional)
return aead.Open(nil, nonce, ciphertext, buf[:])
}

7
go.mod Normal file
View File

@@ -0,0 +1,7 @@
module github.com/fumiama/fumitok
go 1.22.1
require golang.org/x/crypto v0.28.0
require golang.org/x/sys v0.26.0 // indirect

4
go.sum Normal file
View File

@@ -0,0 +1,4 @@
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

15
helper.go Normal file
View File

@@ -0,0 +1,15 @@
package fumitok
import (
"unsafe"
)
// BytesToString 没有内存开销的转换
func BytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
// StringToBytes 没有内存开销的转换
func StringToBytes(s string) (b []byte) {
return unsafe.Slice(unsafe.StringData(s), len(s))
}

183
token.go Normal file
View File

@@ -0,0 +1,183 @@
package fumitok
import (
"bytes"
"crypto/cipher"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"errors"
"hash/crc64"
"math/rand/v2"
"time"
"golang.org/x/crypto/chacha20poly1305"
)
const (
TokenLength = 88
)
var (
ErrInvalidTokenKeySize = errors.New("invalid token key size")
ErrExpiredToken = errors.New("expired token")
ErrInvalidToken = errors.New("invalid token")
ErrInvalidTokenLength = errors.New("invalid token len")
)
// Tokenizer xchacha20 对称加密密钥
type Tokenizer struct {
aead cipher.AEAD
tabl *crc64.Table
}
// NewTokenizer ...
func NewTokenizer(hexkeystr string) (t Tokenizer, err error) {
key, err := hex.DecodeString(hexkeystr)
if err != nil {
return
}
if len(key) != chacha20poly1305.KeySize {
err = ErrInvalidTokenKeySize
return
}
t.aead, err = chacha20poly1305.NewX(key)
if err != nil {
return
}
t.tabl = crc64.MakeTable(crc64.ECMA)
return
}
// Generate 生成 token
//
// - id 用户标识符, 被加密
// - expireAt 有效期至
// - addt, mask 附加数据和其掩码, 不被加密
func (t *Tokenizer) Generate(id uint64, expireAt time.Time, addt, mask uint16) (string, error) {
var buf [2 + 8 + 8 + 8]byte
text := buf[2:]
binary.LittleEndian.PutUint64(text[:8], uint64(expireAt.UnixMilli()))
binary.LittleEndian.PutUint64(text[8:16], id)
h := crc64.New(t.tabl)
_, err := h.Write(text[:16])
if err != nil {
return "", err
}
_ = h.Sum(text[16:16])
addt &= mask
addt |= (uint16(rand.Uint32()) & (^mask))
binary.LittleEndian.PutUint16(buf[:2], addt)
w := bytes.NewBuffer(make([]byte, 0, 64))
enc := base64.NewEncoder(base64.URLEncoding, w)
_, err = enc.Write(buf[:2])
if err != nil {
return "", err
}
_, err = enc.Write(encode(t.aead, addt, text))
if err != nil {
return "", err
}
err = enc.Close()
if err != nil {
return "", err
}
return BytesToString(w.Bytes()), nil
}
// Validate 验证并提取信息
//
// # 参数
// - token 待验证凭据
// - mask 附加数据之掩码, 将在返回时做掩模
// - check 在解码前检查附加数据是否符合要求
//
// # 返回
// - uint64 用户标识符
// - uint16 附加数据
func (t *Tokenizer) Validate(
token string, mask uint16, checks ...func(uint16) error,
) (uint64, uint16, error) {
if len(token) != TokenLength {
return 0, 0, ErrInvalidTokenLength
}
data, err := base64.URLEncoding.DecodeString(token)
if err != nil {
return 0, 0, err
}
addt := binary.LittleEndian.Uint16(data[:2])
addtmsk := addt & mask
for _, fn := range checks {
err = fn(addtmsk)
if err != nil {
return 0, 0, err
}
}
data, err = decode(t.aead, addt, data[2:])
if err != nil {
return 0, 0, err
}
h := crc64.New(t.tabl)
_, err = h.Write(data[:16])
if err != nil {
return 0, 0, err
}
crc := binary.BigEndian.Uint64(data[16:])
if crc != h.Sum64() {
return 0, 0, ErrInvalidToken
}
if time.Now().After(time.UnixMilli(int64(binary.LittleEndian.Uint64(data[:8])))) {
return 0, 0, ErrExpiredToken
}
return binary.LittleEndian.Uint64(data[8:16]), addtmsk, nil
}
// Refresh 过期时刷新 token
//
// - token 旧 token
// - expireAt 新的过期时间
// - validAfter 旧 token 过期此时间段内仍可用于刷新
// - mask 附加数据之掩码
// - check 在解码前检查附加数据是否符合要求
func (t *Tokenizer) Refresh(
token string, expireAt time.Time, validAfter time.Duration,
mask uint16, checks ...func(uint16) error,
) (string, error) {
if len(token) != TokenLength {
return "", ErrInvalidTokenLength
}
data, err := base64.URLEncoding.DecodeString(token)
if err != nil {
return "", err
}
addt := binary.LittleEndian.Uint16(data[:2])
addtmsk := addt & mask
for _, fn := range checks {
err = fn(addtmsk)
if err != nil {
return "", err
}
}
data, err = decode(t.aead, addt, data[2:])
if err != nil {
return "", err
}
h := crc64.New(t.tabl)
_, err = h.Write(data[:16])
if err != nil {
return "", err
}
crc := binary.BigEndian.Uint64(data[16:])
if crc != h.Sum64() {
return "", ErrInvalidToken
}
if time.Now().Add(validAfter).After( // still valid after 30 mins
time.UnixMilli(int64(binary.LittleEndian.Uint64(data[:8]))),
) {
return "", ErrExpiredToken
}
return t.Generate(
binary.LittleEndian.Uint64(data[8:16]),
expireAt, addtmsk, mask,
)
}

42
token_test.go Normal file
View File

@@ -0,0 +1,42 @@
package fumitok
import (
"crypto/rand"
"encoding/hex"
"testing"
"time"
)
func TestTokenizer(t *testing.T) {
var key [32]byte
_, err := rand.Read(key[:])
if err != nil {
t.Fatal(err)
}
tk, err := NewTokenizer(hex.EncodeToString(key[:]))
if err != nil {
t.Fatal(err)
}
id := uint64(3719371987)
token, err := tk.Generate(id, time.Now().Add(time.Minute), 0x1234, 0x00ff)
if err != nil {
t.Fatal(err)
}
t.Log(token)
vid, addt, err := tk.Validate(token, 0x00ff)
if err != nil {
t.Fatal(err)
}
if vid != id || addt != 0x34 {
t.Fatal("id", id, "vid", vid, "addt", addt)
}
token, err = tk.Generate(id, time.Now().Add(-time.Minute), 0, 0)
if err != nil {
t.Fatal(err)
}
t.Log(token)
_, _, err = tk.Validate(token, 0)
if err != ErrExpiredToken {
t.Fatal("unexpected err", err)
}
}