diff --git a/.github/Misaki.png b/.github/Misaki.png new file mode 100644 index 0000000..42af3dc Binary files /dev/null and b/.github/Misaki.png differ diff --git a/README.md b/README.md index 4931708..e1e3ab4 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,14 @@ -# ReiBot -Lightweight Telegram bot framework +
## Instructions @@ -8,7 +17,7 @@ This framework is a simple wrapper for [go-telegram-bot-api](https://github.com/ ## Quick Start > Here is a plugin-based example - + ```go package main @@ -45,7 +54,7 @@ func main() { > If Handler in Bot is implemented, the plugin function will be disabled. - + ```go package main diff --git a/bot.go b/bot.go index 69914d1..950a4ef 100644 --- a/bot.go +++ b/bot.go @@ -9,13 +9,15 @@ import ( type Bot struct { // Token bot 的 token // see https://core.telegram.org/bots#3-how-do-i-create-a-bot - Token string `json:"token"` + Token string // Buffer 控制消息队列的长度 - Buffer int `json:"buffer"` + Buffer int // UpdateConfig 配置消息获取 tgba.UpdateConfig + // SuperUsers 超级用户 + SuperUsers []int64 // Debug 控制调试信息的输出与否 - Debug bool `json:"debug"` + Debug bool // Handler 注册对各种事件的处理 Handler *Handler // handlers 方便调用的 handler diff --git a/context.go b/context.go index e304778..f135c7e 100644 --- a/context.go +++ b/context.go @@ -1,8 +1,22 @@ package rei +import tgba "github.com/go-telegram-bot-api/telegram-bot-api/v5" + type Ctx struct { Event State Caller *TelegramClient ma *Matcher } + +// CheckSession 判断会话连续性 +func (ctx *Ctx) CheckSession() Rule { + msg := ctx.Value.(*tgba.Message) + return func(ctx2 *Ctx) bool { + msg2, ok := ctx.Value.(*tgba.Message) + if !ok || msg.From == nil || msg.Chat == nil || msg2.From == nil || msg2.Chat == nil { // 确保无空 + return false + } + return msg.From.ID == msg2.From.ID && msg.Chat.ID == msg2.Chat.ID + } +} diff --git a/engine.go b/engine.go index 8ab1d8d..914bf98 100644 --- a/engine.go +++ b/engine.go @@ -159,12 +159,12 @@ func (e *Engine) OnChatJoinRequest(rules ...Rule) *Matcher { return e.On("ChatJo // OnChatJoinRequest ... func OnChatJoinRequest(rules ...Rule) *Matcher { return On("ChatJoinRequest", rules...) } -// OnPrefix 前缀触发器 +// OnMessagePrefix 前缀触发器 func OnMessagePrefix(prefix string, rules ...Rule) *Matcher { return defaultEngine.OnMessagePrefix(prefix, rules...) } -// OnPrefix 前缀触发器 +// OnMessagePrefix 前缀触发器 func (e *Engine) OnMessagePrefix(prefix string, rules ...Rule) *Matcher { matcher := &Matcher{ Type: "Message", @@ -174,3 +174,179 @@ func (e *Engine) OnMessagePrefix(prefix string, rules ...Rule) *Matcher { e.matchers = append(e.matchers, matcher) return StoreMatcher(matcher) } + +// OnMessageSuffix 后缀触发器 +func OnMessageSuffix(suffix string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageSuffix(suffix, rules...) +} + +// OnMessageSuffix 后缀触发器 +func (e *Engine) OnMessageSuffix(suffix string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{SuffixRule(suffix)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageCommand 命令触发器 +func OnMessageCommand(commands string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageCommand(commands, rules...) +} + +// OnMessageCommand 命令触发器 +func (e *Engine) OnMessageCommand(commands string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{CommandRule(commands)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageRegex 正则触发器 +func OnMessageRegex(regexPattern string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageRegex(regexPattern, rules...) +} + +// OnRegex 正则触发器 +func (e *Engine) OnMessageRegex(regexPattern string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{RegexRule(regexPattern)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageKeyword 关键词触发器 +func OnMessageKeyword(keyword string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageKeyword(keyword, rules...) +} + +// OnKeyword 关键词触发器 +func (e *Engine) OnMessageKeyword(keyword string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{KeywordRule(keyword)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageFullMatch 完全匹配触发器 +func OnMessageFullMatch(src string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageFullMatch(src, rules...) +} + +// OnMessageFullMatch 完全匹配触发器 +func (e *Engine) OnMessageFullMatch(src string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{FullMatchRule(src)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageFullMatchGroup 完全匹配触发器组 +func OnMessageFullMatchGroup(src []string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageFullMatchGroup(src, rules...) +} + +// OnMessageFullMatchGroup 完全匹配触发器组 +func (e *Engine) OnMessageFullMatchGroup(src []string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{FullMatchRule(src...)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageKeywordGroup 关键词触发器组 +func OnMessageKeywordGroup(keywords []string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageKeywordGroup(keywords, rules...) +} + +// OnMessageKeywordGroup 关键词触发器组 +func (e *Engine) OnMessageKeywordGroup(keywords []string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{KeywordRule(keywords...)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageCommandGroup 命令触发器组 +func OnMessageCommandGroup(commands []string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageCommandGroup(commands, rules...) +} + +// OnMessageCommandGroup 命令触发器组 +func (e *Engine) OnMessageCommandGroup(commands []string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{CommandRule(commands...)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessagePrefixGroup 前缀触发器组 +func OnMessagePrefixGroup(prefix []string, rules ...Rule) *Matcher { + return defaultEngine.OnMessagePrefixGroup(prefix, rules...) +} + +// OnMessagePrefixGroup 前缀触发器组 +func (e *Engine) OnMessagePrefixGroup(prefix []string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{PrefixRule(prefix...)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageSuffixGroup 后缀触发器组 +func OnMessageSuffixGroup(suffix []string, rules ...Rule) *Matcher { + return defaultEngine.OnMessageSuffixGroup(suffix, rules...) +} + +// OnMessageSuffixGroup 后缀触发器组 +func (e *Engine) OnMessageSuffixGroup(suffix []string, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{SuffixRule(suffix...)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} + +// OnMessageShell shell命令触发器 +func OnMessageShell(command string, model interface{}, rules ...Rule) *Matcher { + return defaultEngine.OnMessageShell(command, model, rules...) +} + +// OnMessageShell shell命令触发器 +func (e *Engine) OnMessageShell(command string, model interface{}, rules ...Rule) *Matcher { + matcher := &Matcher{ + Type: "Message", + Rules: append([]Rule{ShellRule(command, model)}, rules...), + Engine: e, + } + e.matchers = append(e.matchers, matcher) + return StoreMatcher(matcher) +} diff --git a/example/echo/main.go b/example/echo/main.go new file mode 100644 index 0000000..e668061 --- /dev/null +++ b/example/echo/main.go @@ -0,0 +1,18 @@ +package echo + +import ( + rei "github.com/fumiama/ReiBot" + tgba "github.com/go-telegram-bot-api/telegram-bot-api/v5" +) + +func init() { + rei.OnMessagePrefix("echo").SetBlock(true).SecondPriority(). + Handle(func(ctx *rei.Ctx) { + args := ctx.State["args"].(string) + if args == "" { + return + } + msg := ctx.Value.(*tgba.Message) + ctx.Caller.Send(tgba.NewMessage(msg.Chat.ID, args)) + }) +} diff --git a/example/main.go b/example/main.go index e0f205b..80b136b 100644 --- a/example/main.go +++ b/example/main.go @@ -1,19 +1,17 @@ package main import ( + _ "github.com/fumiama/ReiBot/example/echo" + rei "github.com/fumiama/ReiBot" tgba "github.com/go-telegram-bot-api/telegram-bot-api/v5" ) func main() { - rei.OnMessagePrefix("echo").SetBlock(true).SecondPriority(). + rei.OnMessageFullMatch("help").SetBlock(true).SecondPriority(). Handle(func(ctx *rei.Ctx) { - args := ctx.State["args"].(string) - if args == "" { - return - } msg := ctx.Value.(*tgba.Message) - ctx.Caller.Send(tgba.NewMessage(msg.Chat.ID, args)) + ctx.Caller.Send(tgba.NewMessage(msg.Chat.ID, "echo string")) }) rei.Run(rei.Bot{ Token: "", diff --git a/future.go b/future.go new file mode 100644 index 0000000..e23a316 --- /dev/null +++ b/future.go @@ -0,0 +1,98 @@ +package rei + +// FutureEvent 是 ZeroBot 交互式的核心,用于异步获取指定事件 +type FutureEvent struct { + Type string + Priority int + Rule []Rule + Block bool +} + +// NewFutureEvent 创建一个FutureEvent, 并返回其指针 +func NewFutureEvent(Type string, Priority int, Block bool, rule ...Rule) *FutureEvent { + return &FutureEvent{ + Type: Type, + Priority: Priority, + Rule: rule, + Block: Block, + } +} + +// FutureEvent 返回一个 FutureEvent 实例指针,用于获取满足 Rule 的 未来事件 +func (m *Matcher) FutureEvent(Type string, rule ...Rule) *FutureEvent { + return &FutureEvent{ + Type: Type, + Priority: m.Priority, + Block: m.Block, + Rule: rule, + } +} + +// Next 返回一个 chan 用于接收下一个指定事件 +// +// 该 chan 必须接收,如需手动取消监听,请使用 Repeat 方法 +func (n *FutureEvent) Next() <-chan *Ctx { + ch := make(chan *Ctx, 1) + StoreTempMatcher(&Matcher{ + Type: n.Type, + Block: n.Block, + Priority: n.Priority, + Rules: n.Rule, + Engine: defaultEngine, + Process: func(ctx *Ctx) { + ch <- ctx + close(ch) + }, + }) + return ch +} + +// Repeat 返回一个 chan 用于接收无穷个指定事件,和一个取消监听的函数 +// +// 如果没有取消监听,将不断监听指定事件 +func (n *FutureEvent) Repeat() (recv <-chan *Ctx, cancel func()) { + ch, done := make(chan *Ctx, 1), make(chan struct{}) + go func() { + defer close(ch) + in := make(chan *Ctx, 1) + matcher := StoreMatcher(&Matcher{ + Type: n.Type, + Block: n.Block, + Priority: n.Priority, + Rules: n.Rule, + Engine: defaultEngine, + Process: func(ctx *Ctx) { + in <- ctx + }, + }) + for { + select { + case e := <-in: + ch <- e + case <-done: + matcher.Delete() + close(in) + return + } + } + }() + return ch, func() { + close(done) + } +} + +// Take 基于 Repeat 封装,返回一个 chan 接收指定数量的事件 +// +// 该 chan 对象必须接收,否则将有 goroutine 泄漏,如需手动取消请使用 Repeat +func (n *FutureEvent) Take(num int) <-chan *Ctx { + recv, cancel := n.Repeat() + ch := make(chan *Ctx, num) + go func() { + defer close(ch) + for i := 0; i < num; i++ { + ch <- <-recv + } + cancel() + }() + return ch +} diff --git a/go.mod b/go.mod index 9c5171d..93e222a 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,13 @@ module github.com/fumiama/ReiBot go 1.18 -require github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 +require ( + github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 + github.com/stretchr/testify v1.7.1 +) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/go.sum b/go.sum index db8e45c..2056981 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,13 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rules.go b/rules.go index 6964c75..dc1fad5 100644 --- a/rules.go +++ b/rules.go @@ -1,12 +1,15 @@ package rei import ( + "reflect" + "regexp" "strings" + "time" tgba "github.com/go-telegram-bot-api/telegram-bot-api/v5" ) -// PrefixRule check if the message has the prefix and trim the prefix +// PrefixRule check if the text message has the prefix and trim the prefix // // 检查消息前缀 func PrefixRule(prefixes ...string) Rule { @@ -26,3 +29,274 @@ func PrefixRule(prefixes ...string) Rule { return false } } + +// SuffixRule check if the text message has the suffix and trim the suffix +// +// 检查消息后缀 +func SuffixRule(suffixes ...string) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" { // 确保无空 + return false + } + for _, suffix := range suffixes { + if strings.HasSuffix(msg.Text, suffix) { + ctx.State["suffix"] = suffix + arg := strings.TrimRight(msg.Text[:len(msg.Text)-len(suffix)], " ") + ctx.State["args"] = arg + return true + } + } + return false + } +} + +// CommandRule check if the message is a command and trim the command name +func CommandRule(commands ...string) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" || !msg.IsCommand() { // 确保无空 + return false + } + ctx.State["command"] = msg.CommandWithAt() + ctx.State["args"] = msg.CommandArguments() + return true + } +} + +// RegexRule check if the message can be matched by the regex pattern +func RegexRule(regexPattern string) Rule { + regex := regexp.MustCompile(regexPattern) + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" { // 确保无空 + return false + } + if matched := regex.FindStringSubmatch(msg.Text); matched != nil { + ctx.State["regex_matched"] = matched + return true + } + return false + } +} + +// ReplyRule check if the message is replying some message +func ReplyRule(messageID int) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.ReplyToMessage == nil { // 确保无空 + return false + } + return messageID == msg.MessageID + } +} + +// KeywordRule check if the message has a keyword or keywords +func KeywordRule(src ...string) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" { // 确保无空 + return false + } + for _, str := range src { + if strings.Contains(msg.Text, str) { + ctx.State["keyword"] = str + return true + } + } + return false + } +} + +// FullMatchRule check if src has the same copy of the message +func FullMatchRule(src ...string) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" { // 确保无空 + return false + } + for _, str := range src { + if str == msg.Text { + ctx.State["matched"] = msg.Text + return true + } + } + return false + } +} + +// ShellRule 定义shell-like规则 +func ShellRule(cmd string, model interface{}) Rule { + cmdRule := CommandRule(cmd) + t := reflect.TypeOf(model) + return func(ctx *Ctx) bool { + if !cmdRule(ctx) { + return false + } + // bind flag to struct + args := ParseShell(ctx.State["args"].(string)) + val := reflect.New(t) + fs := registerFlag(t, val) + err := fs.Parse(args) + if err != nil { + return false + } + ctx.State["args"] = fs.Args() + ctx.State["flag"] = val.Interface() + return true + } +} + +// OnlyToMe only triggered in conditions of @bot or begin with the nicknames +func OnlyToMe(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Text == "" { // 确保无空 + return false + } + name := ctx.Caller.Self.String() + if strings.HasPrefix(msg.Text, name) { + return true + } + n := 0 + for _, e := range msg.Entities { + if e.IsMention() && e.Length > 0 && msg.Text[n+1:n+e.Length] == name { + return true + } + n += e.Length + } + return false +} + +// CheckUser only triggered by specific person +func CheckUser(userId ...int64) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.From == nil { // 确保无空 + return false + } + for _, uid := range userId { + if msg.From.ID == uid { + return true + } + } + return false + } +} + +// CheckChat only triggered in specific chat +func CheckChat(chatId ...int64) Rule { + return func(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.Chat == nil { // 确保无空 + return false + } + for _, cid := range chatId { + if msg.Chat.ID == cid { + return true + } + } + return false + } +} + +// SuperUserPermission only triggered by the bot's owner +func SuperUserPermission(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.From == nil { // 确保无空 + return false + } + for _, su := range ctx.Caller.b.SuperUsers { + if su == msg.From.ID { + return true + } + } + return false +} + +// CreaterPermission only triggered by the group creater or higher permission +func CreaterPermission(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.From == nil || msg.Chat == nil { // 确保无空 + return false + } + for _, su := range ctx.Caller.b.SuperUsers { + if su == msg.From.ID { + return true + } + } + m, err := ctx.Caller.GetChatMember( + tgba.GetChatMemberConfig{ + ChatConfigWithUser: tgba.ChatConfigWithUser{ + ChatID: msg.Chat.ID, + UserID: msg.From.ID, + }, + }, + ) + if err != nil { + return false + } + return m.IsCreator() +} + +// AdminPermission only triggered by the group admins or higher permission +func AdminPermission(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || msg.From == nil || msg.Chat == nil { // 确保无空 + return false + } + for _, su := range ctx.Caller.b.SuperUsers { + if su == msg.From.ID { + return true + } + } + m, err := ctx.Caller.GetChatMember( + tgba.GetChatMemberConfig{ + ChatConfigWithUser: tgba.ChatConfigWithUser{ + ChatID: msg.Chat.ID, + UserID: msg.From.ID, + }, + }, + ) + if err != nil { + return false + } + return m.IsCreator() || m.IsAdministrator() +} + +// IsPhoto 消息是图片返回 true +func IsPhoto(ctx *Ctx) bool { + msg, ok := ctx.Value.(*tgba.Message) + if !ok || len(msg.Photo) == 0 { // 确保无空 + return false + } + ctx.State["photos"] = msg.Photo + return true +} + +// MustProvidePhoto 消息不存在图片阻塞120秒至有图片,超时返回 false +func MustProvidePhoto(ctx *Ctx, needphohint, failhint string) bool { + msg, ok := ctx.Value.(*tgba.Message) + if ok && len(msg.Photo) > 0 { // 确保无空 + ctx.State["photos"] = msg.Photo + return true + } + // 没有图片就索取 + if needphohint != "" { + _, err := ctx.Caller.Send(tgba.NewMessage(msg.Chat.ID, needphohint)) + if err != nil { + return false + } + } + next := NewFutureEvent("message", 999, false, ctx.CheckSession(), IsPhoto).Next() + select { + case <-time.After(time.Second * 120): + if failhint != "" { + _, _ = ctx.Caller.Send(tgba.NewMessage(msg.Chat.ID, failhint)) + } + return false + case newCtx := <-next: + ctx.State["photos"] = newCtx.State["photos"] + ctx.Event = newCtx.Event + return true + } +} diff --git a/shell.go b/shell.go new file mode 100644 index 0000000..915f355 --- /dev/null +++ b/shell.go @@ -0,0 +1,131 @@ +package rei + +import ( + "flag" + "reflect" + "strings" +) + +func isSpace(r rune) bool { + switch r { + case ' ', '\t', '\r', '\n': + return true + } + return false +} + +type argType int + +const ( + argNo argType = iota + argSingle + argQuoted +) + +// ParseShell 将指令转换为指令参数. +// modified from https://github.com/mattn/go-shellwords +func ParseShell(s string) []string { + var args []string + buf := strings.Builder{} + var escaped, doubleQuoted, singleQuoted, backQuote bool + backtick := "" + + got := argNo + + for _, r := range s { + if escaped { + buf.WriteRune(r) + escaped = false + got = argSingle + continue + } + + if r == '\\' { + if singleQuoted { + buf.WriteRune(r) + } else { + escaped = true + } + continue + } + + if isSpace(r) { + if singleQuoted || doubleQuoted || backQuote { + buf.WriteRune(r) + backtick += string(r) + } else if got != argNo { + args = append(args, buf.String()) + buf.Reset() + got = argNo + } + continue + } + + switch r { + case '`': + if !singleQuoted && !doubleQuoted { + backtick = "" + backQuote = !backQuote + } + case '"': + if !singleQuoted { + if doubleQuoted { + got = argQuoted + } + doubleQuoted = !doubleQuoted + } + case '\'': + if !doubleQuoted { + if singleQuoted { + got = argSingle + } + singleQuoted = !singleQuoted + } + default: + got = argSingle + buf.WriteRune(r) + if backQuote { + backtick += string(r) + } + } + } + + if got != argNo { + args = append(args, buf.String()) + } + + return args +} + +var ( + boolType = reflect.TypeOf(false) + intType = reflect.TypeOf(0) + stringType = reflect.TypeOf("") + float64Type = reflect.TypeOf(float64(0)) +) + +func registerFlag(t reflect.Type, v reflect.Value) *flag.FlagSet { + v = v.Elem() + fs := flag.NewFlagSet("", flag.ContinueOnError) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + name := field.Tag.Get("flag") + if name == "" { + continue + } + help := field.Tag.Get("help") + switch field.Type { + case boolType: + fs.BoolVar(v.Field(i).Addr().Interface().(*bool), name, false, help) + case intType: + fs.IntVar(v.Field(i).Addr().Interface().(*int), name, 0, help) + case stringType: + fs.StringVar(v.Field(i).Addr().Interface().(*string), name, "", help) + case float64Type: + fs.Float64Var(v.Field(i).Addr().Interface().(*float64), name, 0, help) + default: + panic("unsupported type") + } + } + return fs +} diff --git a/shell_test.go b/shell_test.go new file mode 100644 index 0000000..3e5ac72 --- /dev/null +++ b/shell_test.go @@ -0,0 +1,45 @@ +package rei + +import ( + "reflect" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parse(t *testing.T) { + shellTests := [...]struct { + shell string + expected []string + }{ + {`rm -rf /*`, []string{"rm", "-rf", "/*"}}, + {`echo "cat cat" -n`, []string{"echo", "cat cat", "-n"}}, + {`shutdown halt init`, []string{"shutdown", "halt", "init"}}, + {`test test2`, []string{"test", "test2"}}, + } + for i, v := range shellTests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + out := ParseShell(v.shell) + assert.Equal(t, v.expected, out) + }) + } +} + +func Test_registerFlag(t *testing.T) { + type args struct { + RF bool `flag:"rf"` + File string `flag:"file"` + Count int `flag:"count"` + } + got := args{} + expected := args{ + RF: true, + File: "123", + Count: 10, + } + fs := registerFlag(reflect.TypeOf(args{}), reflect.ValueOf(&got)) + err := fs.Parse([]string{"-rf", "-file=123", "-count", "10"}) + assert.NoError(t, err) + assert.Equal(t, expected, got) +}