1
0
mirror of https://github.com/fumiama/NanoBot.git synced 2026-06-12 14:10:47 +08:00

feat: add v2 login method

This commit is contained in:
源文雨
2023-11-12 15:56:40 +09:00
parent c9f27c6514
commit 57b0a7c52c
4 changed files with 171 additions and 50 deletions

118
bot.go
View File

@@ -25,9 +25,10 @@ var (
// Bot 一个机器人实例的配置 // Bot 一个机器人实例的配置
type Bot struct { type Bot struct {
AppID string `yaml:"AppID"` // AppID is BotAppID开发者ID AppID string `yaml:"AppID"` // AppID is BotAppID开发者ID
Token string `yaml:"Token"` // Token is 机器人令牌 Token string `yaml:"Token"` // Token is 机器人令牌 有 Secret 则使用新版 API
Secret string `yaml:"Secret"` // Secret is 机器人密钥 token string // token 是通过 secret 获得的残血 token
Secret string `yaml:"Secret"` // Secret is 机器人令牌 V2 (AppSecret/ClientSecret) 沙盒目前虽然能登录但无法收发消息
SuperUsers []string `yaml:"SuperUsers"` // SuperUsers 超级用户 SuperUsers []string `yaml:"SuperUsers"` // SuperUsers 超级用户
Timeout time.Duration `yaml:"Timeout"` // Timeout is API 调用超时 Timeout time.Duration `yaml:"Timeout"` // Timeout is API 调用超时
Handler *Handler `yaml:"-"` // Handler 注册对各种事件的处理 Handler *Handler `yaml:"-"` // Handler 注册对各种事件的处理
@@ -40,10 +41,12 @@ type Bot struct {
gateway string // gateway 获得的网关 gateway string // gateway 获得的网关
seq uint32 // seq 最新的 s seq uint32 // seq 最新的 s
heartbeat uint32 // heartbeat 心跳周期, 单位毫秒 heartbeat uint32 // heartbeat 心跳周期, 单位毫秒
expiresec int64 // expiresec Token 有效时间
handlers map[string]eventHandlerType // handlers 方便调用的 handler handlers map[string]eventHandlerType // handlers 方便调用的 handler
mu sync.Mutex // 写锁 mu sync.Mutex // 写锁
conn *websocket.Conn // conn 目前的 wss 连接 conn *websocket.Conn // conn 目前的 wss 连接
hbonce sync.Once // hbonce 保证仅执行一次 heartbeat hbonce sync.Once // hbonce 保证仅执行一次 heartbeat
exonce sync.Once // exonce 保证仅执行一次刷新 token
client *http.Client // client 主要配置 timeout client *http.Client // client 主要配置 timeout
ready EventReady // ready 连接成功后下发的 bot 基本信息 ready EventReady // ready 连接成功后下发的 bot 基本信息
@@ -55,33 +58,37 @@ func (ctx *Ctx) GetReady() *EventReady {
} }
// getinitinfo 获得 gateway 和 shard // getinitinfo 获得 gateway 和 shard
func (b *Bot) getinitinfo() (gw string, shard [2]byte, err error) { func (bot *Bot) getinitinfo() (secret, gw string, shard [2]byte, err error) {
shard[1] = 1 shard[1] = 1
if b.client == nil { if bot.client == nil {
b.client = http.DefaultClient bot.client = http.DefaultClient
} }
if b.ShardIndex == 0 { secret = bot.Secret
gw, err = b.GetGeneralWSSGateway() if bot.Secret != "" {
bot.Secret = ""
}
if bot.ShardIndex == 0 {
gw, err = bot.GetGeneralWSSGateway()
if err != nil { if err != nil {
return return
} }
} else { } else {
var sgw *ShardWSSGateway var sgw *ShardWSSGateway
sgw, err = b.GetShardWSSGateway() sgw, err = bot.GetShardWSSGateway()
if err != nil { if err != nil {
return return
} }
if b.ShardCount == 0 { if bot.ShardCount == 0 {
log.Infoln(getLogHeader(), "使用网关推荐Shards数:", sgw.Shards) log.Infoln(getLogHeader(), "使用网关推荐Shards数:", sgw.Shards)
b.ShardCount = uint8(sgw.Shards) bot.ShardCount = uint8(sgw.Shards)
} }
if b.ShardCount <= b.ShardIndex { if bot.ShardCount <= bot.ShardIndex {
err = errors.New("shard index " + strconv.Itoa(int(b.ShardIndex)) + " >= suggested size " + strconv.Itoa(sgw.Shards)) err = errors.New("shard index " + strconv.Itoa(int(bot.ShardIndex)) + " >= suggested size " + strconv.Itoa(sgw.Shards))
return return
} }
gw = sgw.URL gw = sgw.URL
shard[0] = byte(b.ShardIndex) shard[0] = byte(bot.ShardIndex)
shard[1] = byte(b.ShardCount) shard[1] = byte(bot.ShardCount)
} }
return return
} }
@@ -92,11 +99,11 @@ func Start(bots ...*Bot) error {
log.Warnln(getLogHeader(), "已忽略重复调用的", getThisFuncName()) log.Warnln(getLogHeader(), "已忽略重复调用的", getThisFuncName())
} }
for _, b := range bots { for _, b := range bots {
gw, shard, err := b.getinitinfo() s, gw, shard, err := b.getinitinfo()
if err != nil { if err != nil {
return err return err
} }
go b.Init(gw, shard).Connect().Listen() go b.Init(s, gw, shard).Connect().Listen()
} }
return nil return nil
} }
@@ -112,25 +119,25 @@ func Run(preblock func(), bots ...*Bot) error {
return nil return nil
case 1: case 1:
b = bots[0] b = bots[0]
gw, shard, err := b.getinitinfo() s, gw, shard, err := b.getinitinfo()
if err != nil { if err != nil {
return err return err
} }
b.Init(gw, shard) b.Init(s, gw, shard)
default: default:
for _, b := range bots[:len(bots)-1] { for _, b := range bots[:len(bots)-1] {
gw, shard, err := b.getinitinfo() s, gw, shard, err := b.getinitinfo()
if err != nil { if err != nil {
return err return err
} }
go b.Init(gw, shard).Connect().Listen() go b.Init(s, gw, shard).Connect().Listen()
} }
b = bots[len(bots)-1] b = bots[len(bots)-1]
gw, shard, err := b.getinitinfo() s, gw, shard, err := b.getinitinfo()
if err != nil { if err != nil {
return err return err
} }
b.Init(gw, shard) b.Init(s, gw, shard)
} }
b.Connect() b.Connect()
if preblock != nil { if preblock != nil {
@@ -141,19 +148,19 @@ func Run(preblock func(), bots ...*Bot) error {
} }
// Init 初始化, 只需执行一次 // Init 初始化, 只需执行一次
func (b *Bot) Init(gateway string, shard [2]byte) *Bot { func (bot *Bot) Init(secret, gateway string, shard [2]byte) *Bot {
b.gateway = gateway bot.gateway = gateway
b.shard = shard bot.shard = shard
if b.Timeout == 0 { if bot.Timeout == 0 {
b.Timeout = time.Minute bot.Timeout = time.Minute
} }
b.client = &http.Client{ bot.client = &http.Client{
Timeout: b.Timeout, Timeout: bot.Timeout,
} }
if b.Handler != nil { if bot.Handler != nil {
h := reflect.ValueOf(b.Handler).Elem() h := reflect.ValueOf(bot.Handler).Elem()
t := h.Type() t := h.Type()
b.handlers = make(map[string]eventHandlerType, h.NumField()*4) bot.handlers = make(map[string]eventHandlerType, h.NumField()*4)
for i := 0; i < h.NumField(); i++ { for i := 0; i < h.NumField(); i++ {
f := h.Field(i) f := h.Field(i)
if f.IsZero() { if f.IsZero() {
@@ -162,17 +169,40 @@ func (b *Bot) Init(gateway string, shard [2]byte) *Bot {
tp := t.Field(i).Name[2:] // skip On tp := t.Field(i).Name[2:] // skip On
log.Infoln(getLogHeader(), "注册处理函数", tp) log.Infoln(getLogHeader(), "注册处理函数", tp)
handler := f.Interface() handler := f.Interface()
b.handlers[tp] = eventHandlerType{ bot.handlers[tp] = eventHandlerType{
h: *(*generalHandleType)(unsafe.Add(unsafe.Pointer(&handler), unsafe.Sizeof(uintptr(0)))), h: *(*generalHandleType)(unsafe.Add(unsafe.Pointer(&handler), unsafe.Sizeof(uintptr(0)))),
t: t.Field(i).Type.In(2).Elem(), t: t.Field(i).Type.In(2).Elem(),
} }
} }
} }
return b bot.Secret = secret
if bot.IsV2() {
for {
err := bot.GetAppAccessToken()
if err == nil {
log.Infoln(getLogHeader(), "获得 Token: "+bot.token+", 超时:", bot.expiresec, "秒")
bot.exonce.Do(func() {
go bot.refreshtoken()
})
break
}
log.Infoln(getLogHeader(), "获得 Token 失败:", err)
time.Sleep(time.Second * 3)
}
}
return bot
}
// IsV2 判断是否运行于 V2 API 下
func (bot *Bot) IsV2() bool {
return bot.Secret != ""
} }
// Authorization 返回 Authorization Header value // Authorization 返回 Authorization Header value
func (bot *Bot) Authorization() string { func (bot *Bot) Authorization() string {
if bot.IsV2() {
return "QQBot " + bot.token
}
return "Bot " + bot.AppID + "." + bot.Token return "Bot " + bot.AppID + "." + bot.Token
} }
@@ -271,6 +301,24 @@ func (bot *Bot) Connect() *Bot {
return bot return bot
} }
// refreshtoken 以 Expire 为间隔刷新 Token
func (bot *Bot) refreshtoken() {
for {
time.Sleep(time.Second * 10)
if atomic.LoadUint32(&bot.heartbeat) == 0 {
log.Warnln(getLogHeader(), "等待服务器建立连接...")
continue
}
time.Sleep(time.Duration(bot.expiresec) * time.Second)
err := bot.GetAppAccessToken()
if err != nil {
log.Warnln(getLogHeader(), "刷新 Token 时出现错误:", err)
} else {
log.Infoln(getLogHeader(), "刷新 Token: "+bot.token+", 超时:", bot.expiresec, "秒")
}
}
}
// doheartbeat 按指定间隔进行心跳包发送 // doheartbeat 按指定间隔进行心跳包发送
func (bot *Bot) doheartbeat() { func (bot *Bot) doheartbeat() {
payload := struct { payload := struct {

33
http.go
View File

@@ -19,14 +19,19 @@ import (
) )
// HTTPRequsetConstructer ... // HTTPRequsetConstructer ...
type HTTPRequsetConstructer func(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) type HTTPRequsetConstructer func(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error)
func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth string, body io.Reader) (req *http.Request, err error) { func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth, appid string, body io.Reader) (req *http.Request, err error) {
req, err = http.NewRequest(method, OpenAPI+ep, body) req, err = http.NewRequest(method, ep, body)
if err != nil { if err != nil {
return return
} }
req.Header.Add("Authorization", auth) if auth != "" {
req.Header.Add("Authorization", auth)
}
if appid != "" {
req.Header.Add("X-Union-Appid", appid)
}
if contenttype == "" { if contenttype == "" {
contenttype = "application/json" contenttype = "application/json"
} }
@@ -35,28 +40,28 @@ func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth string,
} }
// NewHTTPEndpointGetRequestWithAuth 新建带鉴权头的 HTTP GET 请求 // NewHTTPEndpointGetRequestWithAuth 新建带鉴权头的 HTTP GET 请求
func NewHTTPEndpointGetRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { func NewHTTPEndpointGetRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) {
return newHTTPEndpointRequestWithAuth("GET", contenttype, ep, auth, body) return newHTTPEndpointRequestWithAuth("GET", contenttype, OpenAPI+ep, auth, appid, body)
} }
// NewHTTPEndpointPutRequestWithAuth 新建带鉴权头的 HTTP PUT 请求 // NewHTTPEndpointPutRequestWithAuth 新建带鉴权头的 HTTP PUT 请求
func NewHTTPEndpointPutRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { func NewHTTPEndpointPutRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) {
return newHTTPEndpointRequestWithAuth("PUT", contenttype, ep, auth, body) return newHTTPEndpointRequestWithAuth("PUT", contenttype, OpenAPI+ep, auth, appid, body)
} }
// NewHTTPEndpointDeleteRequestWithAuth 新建带鉴权头的 HTTP DELETE 请求 // NewHTTPEndpointDeleteRequestWithAuth 新建带鉴权头的 HTTP DELETE 请求
func NewHTTPEndpointDeleteRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { func NewHTTPEndpointDeleteRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) {
return newHTTPEndpointRequestWithAuth("DELETE", contenttype, ep, auth, body) return newHTTPEndpointRequestWithAuth("DELETE", contenttype, OpenAPI+ep, auth, appid, body)
} }
// NewHTTPEndpointPostRequestWithAuth 新建带鉴权头的 HTTP POST 请求 // NewHTTPEndpointPostRequestWithAuth 新建带鉴权头的 HTTP POST 请求
func NewHTTPEndpointPostRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { func NewHTTPEndpointPostRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) {
return newHTTPEndpointRequestWithAuth("POST", contenttype, ep, auth, body) return newHTTPEndpointRequestWithAuth("POST", contenttype, OpenAPI+ep, auth, appid, body)
} }
// NewHTTPEndpointPatchRequestWithAuth 新建带鉴权头的 HTTP PATCH 请求 // NewHTTPEndpointPatchRequestWithAuth 新建带鉴权头的 HTTP PATCH 请求
func NewHTTPEndpointPatchRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { func NewHTTPEndpointPatchRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) {
return newHTTPEndpointRequestWithAuth("PATCH", contenttype, ep, auth, body) return newHTTPEndpointRequestWithAuth("PATCH", contenttype, OpenAPI+ep, auth, appid, body)
} }
// WriteHTTPQueryIfNotNil 如果非空则将请求添加到 baseurl 后 // WriteHTTPQueryIfNotNil 如果非空则将请求添加到 baseurl 后

View File

@@ -17,6 +17,8 @@ const (
StandardAPI = `https://api.sgroup.qq.com` StandardAPI = `https://api.sgroup.qq.com`
// SandboxAPI 沙箱环境接口域名 // SandboxAPI 沙箱环境接口域名
SandboxAPI = `https://sandbox.api.sgroup.qq.com` SandboxAPI = `https://sandbox.api.sgroup.qq.com`
// AccessTokenAPI 获取接口凭证的 API
AccessTokenAPI = "https://bots.qq.com/app/getAppAccessToken"
) )
var ( var (
@@ -30,7 +32,11 @@ type CodeMessageBase struct {
} }
func (bot *Bot) dohttprequest(constructer HTTPRequsetConstructer, ep, contenttype string, ptr any, body io.Reader) error { func (bot *Bot) dohttprequest(constructer HTTPRequsetConstructer, ep, contenttype string, ptr any, body io.Reader) error {
req, err := constructer(ep, contenttype, bot.Authorization(), body) appid := ""
if bot.IsV2() {
appid = bot.AppID
}
req, err := constructer(ep, contenttype, bot.Authorization(), appid, body)
if err != nil { if err != nil {
return errors.Wrap(err, getCallerFuncName()) return errors.Wrap(err, getCallerFuncName())
} }

View File

@@ -1,5 +1,18 @@
package nano package nano
import (
"encoding/json"
"net/http"
"strconv"
"github.com/pkg/errors"
)
var (
ErrEmptyToken = errors.New("empty token")
ErrInvalidExpire = errors.New("invalid expire")
)
// GetGeneralWSSGateway 获取通用 WSS 接入点 // GetGeneralWSSGateway 获取通用 WSS 接入点
// //
// https://bot.q.qq.com/wiki/develop/api/openapi/wss/url_get.html // https://bot.q.qq.com/wiki/develop/api/openapi/wss/url_get.html
@@ -32,3 +45,52 @@ type ShardWSSGateway struct {
func (bot *Bot) GetShardWSSGateway() (*ShardWSSGateway, error) { func (bot *Bot) GetShardWSSGateway() (*ShardWSSGateway, error) {
return bot.getOpenAPIofShardWSSGateway("/gateway/bot") return bot.getOpenAPIofShardWSSGateway("/gateway/bot")
} }
// GetAppAccessToken 获取接口凭证并保存到 bot.Token
//
// https://bot.q.qq.com/wiki/develop/api-231017/dev-prepare/interface-framework/api-use.html#%E8%8E%B7%E5%8F%96%E6%8E%A5%E5%8F%A3%E5%87%AD%E8%AF%81
func (bot *Bot) GetAppAccessToken() error {
req, err := newHTTPEndpointRequestWithAuth("POST", "", AccessTokenAPI, "", "", WriteBodyFromJSON(&struct {
A string `json:"appId"`
S string `json:"clientSecret"`
}{bot.AppID, bot.Secret}))
if err != nil {
return err
}
resp, err := bot.client.Do(req)
if err != nil {
return errors.Wrap(err, getThisFuncName())
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.Wrap(errors.New("http status code: "+strconv.Itoa(resp.StatusCode)), getThisFuncName())
}
body := struct {
C int `json:"code"`
M string `json:"message"`
T string `json:"access_token"`
E string `json:"expires_in"`
}{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
return errors.Wrap(err, getThisFuncName())
}
if body.C != 0 {
return errors.Wrap(errors.New("code: "+strconv.Itoa(body.C)+", msg: "+body.M), getThisFuncName())
}
if body.T == "" {
return errors.Wrap(ErrEmptyToken, getThisFuncName())
}
if body.E == "" {
return errors.Wrap(ErrInvalidExpire, getThisFuncName())
}
bot.token = body.T
bot.expiresec, err = strconv.ParseInt(body.E, 10, 64)
if err != nil {
return errors.Wrap(err, getThisFuncName())
}
if bot.expiresec <= 0 {
return errors.Wrap(ErrInvalidExpire, getThisFuncName())
}
return nil
}