From 32fdf3ae9070958d7e9fd2222a86b95aa3e6a2a7 Mon Sep 17 00:00:00 2001 From: chasey-dev Date: Wed, 21 May 2025 16:09:57 +0800 Subject: [PATCH] feat: auto select server IP when not specified (#4) --- README.md | 18 +++++++++++---- cmd/main.go | 22 ++++++++---------- portal/portal.go | 54 ++++++++++++++++++++++++++++++------------- portal/server_test.go | 27 ++++++++++++++-------- 4 files changed, 79 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index d9ebf4e..3303bd4 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,23 @@ $ go install github.com/fumiama/go-nd-portal@latest > 也可不带参数运行,会在启动时询问参数 -```bash -./go-nd-portal -n 20xxxxxxxxxxx -p password [-x] +``` +./go-nd-portal -n 20xxxxxxxxxxx -p password [-t ] ``` 默认值: * `-ip`: 本机公网出口,可自定义 - * `-x`: 是否使用电信登陆(否) + + * `-t`: 登录类型(`qsh-edu`),可指定为: + * 清水河,教学办公区: + * `qsh-edu`, 教育网 + * `qsh-dx`, 电信 + * 清水河,新建宿舍区: + * `qshd-dx`, 电信 + * `qshd-cmcc`, 移动 + + * `-s`: 服务器地址(根据上述登录类型自动选择),可自定义 + ## 效果 -screenshot +screenshot diff --git a/cmd/main.go b/cmd/main.go index 51a9655..4bfe27b 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -51,7 +51,7 @@ func Main() { h := flag.Bool("h", false, "display this help") w := flag.Bool("w", false, "only display warn-or-higher-level log") d := flag.Bool("d", false, "display debug-level log") - s := flag.String("s", portal.PortalServerIPQsh, "login host") + s := flag.String("s", "", "login host, auto select when empty") t := flag.String("t", "qsh-edu", "login type [qsh-edu | qsh-dx | qshd-dx | qshd-cmcc]") flag.Parse() if *h { @@ -99,8 +99,7 @@ func Main() { *p = helper.BytesToString(data) fmt.Println() } - logrus.Debugf("server addr: %s, login type: %s", *s, *t) - if *s != portal.PortalServerIPQsh { + if *s != "" { // just validate IP here, // dont convert to net.IP because we need only its string later _, err := netip.ParseAddr(*s) @@ -113,22 +112,19 @@ func Main() { // p: password // ip : public ip // *t : login type - ptl, err := portal.NewPortal(*n, *p, ip, portal.LoginType(*t)) + ptl, err := portal.NewPortal(*n, *p, *s, ip, portal.LoginType(*t)) + if err != nil { + logrus.Errorln(err) + os.Exit(line()) + } + challenge, err := ptl.GetChallenge() if err != nil { logrus.Errorln(err) os.Exit(line()) } // input: - // server IP - challenge, err := ptl.GetChallenge(*s) - if err != nil { - logrus.Errorln(err) - os.Exit(line()) - } - // input: - // server IP // challenge - err = ptl.Login(*s, challenge) + err = ptl.Login(challenge) if err != nil { logrus.Errorln(err) os.Exit(line()) diff --git a/portal/portal.go b/portal/portal.go index ddbfe4a..20a201e 100644 --- a/portal/portal.go +++ b/portal/portal.go @@ -30,7 +30,8 @@ var ( type Portal struct { name string pswd string - ip net.IP + cip net.IP + sip string domain string acid string } @@ -49,6 +50,21 @@ const ( LoginTypeQshDormCMCC LoginType = "qshd-cmcc" ) +// GetDefaultPortalServerIP returns default PortalServerIP by LoginType +func (lt LoginType) GetDefaultPortalServerIP() (string, error) { + var sIP string + switch lt { + case LoginTypeQshEdu, LoginTypeQshDX: + sIP = PortalServerIPQsh + case LoginTypeQshDormDX, LoginTypeQshDormCMCC: + sIP = PortalServerIPQshDorm + default: + return "", ErrIllegalLoginType + } + + return sIP, nil +} + // ToDomainAcID converts LoginType to domain and acid func (lt LoginType) ToDomainAcID() (string, string, error) { var domain, acid string @@ -80,8 +96,8 @@ type rsp struct { } // NewPortal creates a new Portal instance -func NewPortal(name, password string, ipv4 net.IP, loginType LoginType) (*Portal, error) { - if len(ipv4) != 4 { +func NewPortal(name, password, sIP string, cIP net.IP, loginType LoginType) (*Portal, error) { + if len(cIP) != 4 { return nil, ErrIllegalIPv4 } @@ -89,28 +105,35 @@ func NewPortal(name, password string, ipv4 net.IP, loginType LoginType) (*Portal if err != nil { return nil, err } - logrus.Debugf("portal domain: %s, ac_id: %s", domain, acid) + logrus.Debugf("login type: %s, portal domain: %s, ac_id: %s", loginType, domain, acid) + + if sIP == "" { + sIP, err = loginType.GetDefaultPortalServerIP() + if err != nil { + return nil, err + } + } + logrus.Debugf("server addr: %s", sIP) return &Portal{ name: name, pswd: password, - ip: ipv4, + cip: cIP, + sip: sIP, domain: domain, acid: acid, }, nil } // GetChallenge gets token for encryption from server -// input: -// server IP -func (p *Portal) GetChallenge(sIP string) (string, error) { +func (p *Portal) GetChallenge() (string, error) { // Note: no need to do URL encoding here u, err := GetChallengeURL( - sIP, + p.sip, "gondportal", p.name, p.domain, - p.ip, + p.cip, time.Now().UnixMilli(), ) @@ -148,10 +171,9 @@ func (p *Portal) PasswordHMd5(challenge string) string { // Login sends login request to server // input: -// server IP // challenge -func (p *Portal) Login(sIP, challenge string) error { - userInfo, err := GetUserInfo(p.name, p.domain, p.pswd, p.ip, p.acid) +func (p *Portal) Login(challenge string) error { + userInfo, err := GetUserInfo(p.name, p.domain, p.pswd, p.cip, p.acid) if err != nil { return err } @@ -159,14 +181,14 @@ func (p *Portal) Login(sIP, challenge string) error { hmd5 := p.PasswordHMd5(challenge) // Note: no need to do URL encoding here u, err := GetLoginURL( - sIP, + p.sip, "gondportal", p.name, p.domain, hmd5, p.acid, - p.ip, - p.CheckSum(challenge, p.name, p.domain, hmd5, p.acid, p.ip, info), + p.cip, + p.CheckSum(challenge, p.name, p.domain, hmd5, p.acid, p.cip, info), info, time.Now().UnixMilli(), ) diff --git a/portal/server_test.go b/portal/server_test.go index 147e415..0b4c82e 100644 --- a/portal/server_test.go +++ b/portal/server_test.go @@ -12,12 +12,21 @@ import ( "github.com/fumiama/go-nd-portal/helper" ) -func TestGetUserInfo(t *testing.T) { - u, err := NewPortal("2000010101001", "12345678", net.IPv4(1, 2, 3, 4).To4(),"qsh-edu") +func TestAutoSelectServerIP(t *testing.T) { + u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu) if err != nil { t.Fatal(err) } - info, err := GetUserInfo(u.name, u.domain, u.pswd, u.ip, u.acid) + t.Log(u.sip) + assert.Equal(t, PortalServerIPQsh, u.sip) +} + +func TestGetUserInfo(t *testing.T) { + u, err := NewPortal("2000010101001", "12345678", "", net.IPv4(1, 2, 3, 4).To4(), LoginTypeQshEdu) + if err != nil { + t.Fatal(err) + } + info, err := GetUserInfo(u.name, u.domain, u.pswd, u.cip, u.acid) if err != nil { t.Fatal(err) } @@ -53,11 +62,11 @@ func TestDecodeKey(t *testing.T) { } func TestEncodeUserInfo(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", net.IPv4(113, 54, 148, 243).To4(),"qsh-edu") + u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) if err != nil { t.Fatal(err) } - info, err := GetUserInfo(u.name, u.domain, u.pswd, u.ip, u.acid) + info, err := GetUserInfo(u.name, u.domain, u.pswd, u.cip, u.acid) if err != nil { t.Fatal(err) } @@ -67,7 +76,7 @@ func TestEncodeUserInfo(t *testing.T) { } func TestHMd5(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", net.IPv4(113, 54, 148, 243).To4(),"qsh-edu") + u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) if err != nil { t.Fatal(err) } @@ -81,11 +90,11 @@ func TestSha1(t *testing.T) { } func TestCheckSum(t *testing.T) { - u, err := NewPortal("2001010101001", "1234567890", net.IPv4(113, 54, 148, 243).To4(),"qsh-edu") + u, err := NewPortal("2001010101001", "1234567890", "", net.IPv4(113, 54, 148, 243).To4(), LoginTypeQshEdu) if err != nil { t.Fatal(err) } - info, err := GetUserInfo(u.name, u.domain, u.pswd, u.ip, u.acid) + info, err := GetUserInfo(u.name, u.domain, u.pswd, u.cip, u.acid) if err != nil { t.Fatal(err) } @@ -97,7 +106,7 @@ func TestCheckSum(t *testing.T) { PortalDomainQsh, u.PasswordHMd5(challenge), u.acid, - u.ip, + u.cip, EncodeUserInfo( info, challenge,