From 57b0a7c52c078d02f4ae047588001d71a61ed883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sun, 12 Nov 2023 15:56:40 +0900 Subject: [PATCH] feat: add v2 login method --- bot.go | 118 ++++++++++++++++++++++++++++++++++--------------- http.go | 33 ++++++++------ openapi.go | 8 +++- openapi_wss.go | 62 ++++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 50 deletions(-) diff --git a/bot.go b/bot.go index 29d409f..661c14c 100644 --- a/bot.go +++ b/bot.go @@ -25,9 +25,10 @@ var ( // Bot 一个机器人实例的配置 type Bot struct { - AppID string `yaml:"AppID"` // AppID is BotAppID(开发者ID) - Token string `yaml:"Token"` // Token is 机器人令牌 - Secret string `yaml:"Secret"` // Secret is 机器人密钥 + AppID string `yaml:"AppID"` // AppID is BotAppID(开发者ID) + Token string `yaml:"Token"` // Token is 机器人令牌 有 Secret 则使用新版 API + token string // token 是通过 secret 获得的残血 token + Secret string `yaml:"Secret"` // Secret is 机器人令牌 V2 (AppSecret/ClientSecret) 沙盒目前虽然能登录但无法收发消息 SuperUsers []string `yaml:"SuperUsers"` // SuperUsers 超级用户 Timeout time.Duration `yaml:"Timeout"` // Timeout is API 调用超时 Handler *Handler `yaml:"-"` // Handler 注册对各种事件的处理 @@ -40,10 +41,12 @@ type Bot struct { gateway string // gateway 获得的网关 seq uint32 // seq 最新的 s heartbeat uint32 // heartbeat 心跳周期, 单位毫秒 + expiresec int64 // expiresec Token 有效时间 handlers map[string]eventHandlerType // handlers 方便调用的 handler mu sync.Mutex // 写锁 conn *websocket.Conn // conn 目前的 wss 连接 hbonce sync.Once // hbonce 保证仅执行一次 heartbeat + exonce sync.Once // exonce 保证仅执行一次刷新 token client *http.Client // client 主要配置 timeout ready EventReady // ready 连接成功后下发的 bot 基本信息 @@ -55,33 +58,37 @@ func (ctx *Ctx) GetReady() *EventReady { } // getinitinfo 获得 gateway 和 shard -func (b *Bot) getinitinfo() (gw string, shard [2]byte, err error) { +func (bot *Bot) getinitinfo() (secret, gw string, shard [2]byte, err error) { shard[1] = 1 - if b.client == nil { - b.client = http.DefaultClient + if bot.client == nil { + bot.client = http.DefaultClient } - if b.ShardIndex == 0 { - gw, err = b.GetGeneralWSSGateway() + secret = bot.Secret + if bot.Secret != "" { + bot.Secret = "" + } + if bot.ShardIndex == 0 { + gw, err = bot.GetGeneralWSSGateway() if err != nil { return } } else { var sgw *ShardWSSGateway - sgw, err = b.GetShardWSSGateway() + sgw, err = bot.GetShardWSSGateway() if err != nil { return } - if b.ShardCount == 0 { + if bot.ShardCount == 0 { log.Infoln(getLogHeader(), "使用网关推荐Shards数:", sgw.Shards) - b.ShardCount = uint8(sgw.Shards) + bot.ShardCount = uint8(sgw.Shards) } - if b.ShardCount <= b.ShardIndex { - err = errors.New("shard index " + strconv.Itoa(int(b.ShardIndex)) + " >= suggested size " + strconv.Itoa(sgw.Shards)) + if bot.ShardCount <= bot.ShardIndex { + err = errors.New("shard index " + strconv.Itoa(int(bot.ShardIndex)) + " >= suggested size " + strconv.Itoa(sgw.Shards)) return } gw = sgw.URL - shard[0] = byte(b.ShardIndex) - shard[1] = byte(b.ShardCount) + shard[0] = byte(bot.ShardIndex) + shard[1] = byte(bot.ShardCount) } return } @@ -92,11 +99,11 @@ func Start(bots ...*Bot) error { log.Warnln(getLogHeader(), "已忽略重复调用的", getThisFuncName()) } for _, b := range bots { - gw, shard, err := b.getinitinfo() + s, gw, shard, err := b.getinitinfo() if err != nil { return err } - go b.Init(gw, shard).Connect().Listen() + go b.Init(s, gw, shard).Connect().Listen() } return nil } @@ -112,25 +119,25 @@ func Run(preblock func(), bots ...*Bot) error { return nil case 1: b = bots[0] - gw, shard, err := b.getinitinfo() + s, gw, shard, err := b.getinitinfo() if err != nil { return err } - b.Init(gw, shard) + b.Init(s, gw, shard) default: for _, b := range bots[:len(bots)-1] { - gw, shard, err := b.getinitinfo() + s, gw, shard, err := b.getinitinfo() if err != nil { return err } - go b.Init(gw, shard).Connect().Listen() + go b.Init(s, gw, shard).Connect().Listen() } b = bots[len(bots)-1] - gw, shard, err := b.getinitinfo() + s, gw, shard, err := b.getinitinfo() if err != nil { return err } - b.Init(gw, shard) + b.Init(s, gw, shard) } b.Connect() if preblock != nil { @@ -141,19 +148,19 @@ func Run(preblock func(), bots ...*Bot) error { } // Init 初始化, 只需执行一次 -func (b *Bot) Init(gateway string, shard [2]byte) *Bot { - b.gateway = gateway - b.shard = shard - if b.Timeout == 0 { - b.Timeout = time.Minute +func (bot *Bot) Init(secret, gateway string, shard [2]byte) *Bot { + bot.gateway = gateway + bot.shard = shard + if bot.Timeout == 0 { + bot.Timeout = time.Minute } - b.client = &http.Client{ - Timeout: b.Timeout, + bot.client = &http.Client{ + Timeout: bot.Timeout, } - if b.Handler != nil { - h := reflect.ValueOf(b.Handler).Elem() + if bot.Handler != nil { + h := reflect.ValueOf(bot.Handler).Elem() t := h.Type() - b.handlers = make(map[string]eventHandlerType, h.NumField()*4) + bot.handlers = make(map[string]eventHandlerType, h.NumField()*4) for i := 0; i < h.NumField(); i++ { f := h.Field(i) if f.IsZero() { @@ -162,17 +169,40 @@ func (b *Bot) Init(gateway string, shard [2]byte) *Bot { tp := t.Field(i).Name[2:] // skip On log.Infoln(getLogHeader(), "注册处理函数", tp) handler := f.Interface() - b.handlers[tp] = eventHandlerType{ + bot.handlers[tp] = eventHandlerType{ h: *(*generalHandleType)(unsafe.Add(unsafe.Pointer(&handler), unsafe.Sizeof(uintptr(0)))), t: t.Field(i).Type.In(2).Elem(), } } } - return b + bot.Secret = secret + if bot.IsV2() { + for { + err := bot.GetAppAccessToken() + if err == nil { + log.Infoln(getLogHeader(), "获得 Token: "+bot.token+", 超时:", bot.expiresec, "秒") + bot.exonce.Do(func() { + go bot.refreshtoken() + }) + break + } + log.Infoln(getLogHeader(), "获得 Token 失败:", err) + time.Sleep(time.Second * 3) + } + } + return bot +} + +// IsV2 判断是否运行于 V2 API 下 +func (bot *Bot) IsV2() bool { + return bot.Secret != "" } // Authorization 返回 Authorization Header value func (bot *Bot) Authorization() string { + if bot.IsV2() { + return "QQBot " + bot.token + } return "Bot " + bot.AppID + "." + bot.Token } @@ -271,6 +301,24 @@ func (bot *Bot) Connect() *Bot { return bot } +// refreshtoken 以 Expire 为间隔刷新 Token +func (bot *Bot) refreshtoken() { + for { + time.Sleep(time.Second * 10) + if atomic.LoadUint32(&bot.heartbeat) == 0 { + log.Warnln(getLogHeader(), "等待服务器建立连接...") + continue + } + time.Sleep(time.Duration(bot.expiresec) * time.Second) + err := bot.GetAppAccessToken() + if err != nil { + log.Warnln(getLogHeader(), "刷新 Token 时出现错误:", err) + } else { + log.Infoln(getLogHeader(), "刷新 Token: "+bot.token+", 超时:", bot.expiresec, "秒") + } + } +} + // doheartbeat 按指定间隔进行心跳包发送 func (bot *Bot) doheartbeat() { payload := struct { diff --git a/http.go b/http.go index 166167c..c2af159 100644 --- a/http.go +++ b/http.go @@ -19,14 +19,19 @@ import ( ) // HTTPRequsetConstructer ... -type HTTPRequsetConstructer func(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) +type HTTPRequsetConstructer func(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) -func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth string, body io.Reader) (req *http.Request, err error) { - req, err = http.NewRequest(method, OpenAPI+ep, body) +func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth, appid string, body io.Reader) (req *http.Request, err error) { + req, err = http.NewRequest(method, ep, body) if err != nil { return } - req.Header.Add("Authorization", auth) + if auth != "" { + req.Header.Add("Authorization", auth) + } + if appid != "" { + req.Header.Add("X-Union-Appid", appid) + } if contenttype == "" { contenttype = "application/json" } @@ -35,28 +40,28 @@ func newHTTPEndpointRequestWithAuth(method, contenttype, ep string, auth string, } // NewHTTPEndpointGetRequestWithAuth 新建带鉴权头的 HTTP GET 请求 -func NewHTTPEndpointGetRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { - return newHTTPEndpointRequestWithAuth("GET", contenttype, ep, auth, body) +func NewHTTPEndpointGetRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) { + return newHTTPEndpointRequestWithAuth("GET", contenttype, OpenAPI+ep, auth, appid, body) } // NewHTTPEndpointPutRequestWithAuth 新建带鉴权头的 HTTP PUT 请求 -func NewHTTPEndpointPutRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { - return newHTTPEndpointRequestWithAuth("PUT", contenttype, ep, auth, body) +func NewHTTPEndpointPutRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) { + return newHTTPEndpointRequestWithAuth("PUT", contenttype, OpenAPI+ep, auth, appid, body) } // NewHTTPEndpointDeleteRequestWithAuth 新建带鉴权头的 HTTP DELETE 请求 -func NewHTTPEndpointDeleteRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { - return newHTTPEndpointRequestWithAuth("DELETE", contenttype, ep, auth, body) +func NewHTTPEndpointDeleteRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) { + return newHTTPEndpointRequestWithAuth("DELETE", contenttype, OpenAPI+ep, auth, appid, body) } // NewHTTPEndpointPostRequestWithAuth 新建带鉴权头的 HTTP POST 请求 -func NewHTTPEndpointPostRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { - return newHTTPEndpointRequestWithAuth("POST", contenttype, ep, auth, body) +func NewHTTPEndpointPostRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) { + return newHTTPEndpointRequestWithAuth("POST", contenttype, OpenAPI+ep, auth, appid, body) } // NewHTTPEndpointPatchRequestWithAuth 新建带鉴权头的 HTTP PATCH 请求 -func NewHTTPEndpointPatchRequestWithAuth(ep string, contenttype string, auth string, body io.Reader) (*http.Request, error) { - return newHTTPEndpointRequestWithAuth("PATCH", contenttype, ep, auth, body) +func NewHTTPEndpointPatchRequestWithAuth(ep string, contenttype string, auth, appid string, body io.Reader) (*http.Request, error) { + return newHTTPEndpointRequestWithAuth("PATCH", contenttype, OpenAPI+ep, auth, appid, body) } // WriteHTTPQueryIfNotNil 如果非空则将请求添加到 baseurl 后 diff --git a/openapi.go b/openapi.go index 1322114..ff79235 100644 --- a/openapi.go +++ b/openapi.go @@ -17,6 +17,8 @@ const ( StandardAPI = `https://api.sgroup.qq.com` // SandboxAPI 沙箱环境接口域名 SandboxAPI = `https://sandbox.api.sgroup.qq.com` + // AccessTokenAPI 获取接口凭证的 API + AccessTokenAPI = "https://bots.qq.com/app/getAppAccessToken" ) var ( @@ -30,7 +32,11 @@ type CodeMessageBase struct { } func (bot *Bot) dohttprequest(constructer HTTPRequsetConstructer, ep, contenttype string, ptr any, body io.Reader) error { - req, err := constructer(ep, contenttype, bot.Authorization(), body) + appid := "" + if bot.IsV2() { + appid = bot.AppID + } + req, err := constructer(ep, contenttype, bot.Authorization(), appid, body) if err != nil { return errors.Wrap(err, getCallerFuncName()) } diff --git a/openapi_wss.go b/openapi_wss.go index 1823257..02a035d 100644 --- a/openapi_wss.go +++ b/openapi_wss.go @@ -1,5 +1,18 @@ package nano +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/pkg/errors" +) + +var ( + ErrEmptyToken = errors.New("empty token") + ErrInvalidExpire = errors.New("invalid expire") +) + // GetGeneralWSSGateway 获取通用 WSS 接入点 // // https://bot.q.qq.com/wiki/develop/api/openapi/wss/url_get.html @@ -32,3 +45,52 @@ type ShardWSSGateway struct { func (bot *Bot) GetShardWSSGateway() (*ShardWSSGateway, error) { return bot.getOpenAPIofShardWSSGateway("/gateway/bot") } + +// GetAppAccessToken 获取接口凭证并保存到 bot.Token +// +// https://bot.q.qq.com/wiki/develop/api-231017/dev-prepare/interface-framework/api-use.html#%E8%8E%B7%E5%8F%96%E6%8E%A5%E5%8F%A3%E5%87%AD%E8%AF%81 +func (bot *Bot) GetAppAccessToken() error { + req, err := newHTTPEndpointRequestWithAuth("POST", "", AccessTokenAPI, "", "", WriteBodyFromJSON(&struct { + A string `json:"appId"` + S string `json:"clientSecret"` + }{bot.AppID, bot.Secret})) + if err != nil { + return err + } + resp, err := bot.client.Do(req) + if err != nil { + return errors.Wrap(err, getThisFuncName()) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return errors.Wrap(errors.New("http status code: "+strconv.Itoa(resp.StatusCode)), getThisFuncName()) + } + body := struct { + C int `json:"code"` + M string `json:"message"` + T string `json:"access_token"` + E string `json:"expires_in"` + }{} + err = json.NewDecoder(resp.Body).Decode(&body) + if err != nil { + return errors.Wrap(err, getThisFuncName()) + } + if body.C != 0 { + return errors.Wrap(errors.New("code: "+strconv.Itoa(body.C)+", msg: "+body.M), getThisFuncName()) + } + if body.T == "" { + return errors.Wrap(ErrEmptyToken, getThisFuncName()) + } + if body.E == "" { + return errors.Wrap(ErrInvalidExpire, getThisFuncName()) + } + bot.token = body.T + bot.expiresec, err = strconv.ParseInt(body.E, 10, 64) + if err != nil { + return errors.Wrap(err, getThisFuncName()) + } + if bot.expiresec <= 0 { + return errors.Wrap(ErrInvalidExpire, getThisFuncName()) + } + return nil +}